{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy.stats as stats\n",
    "import random\n",
    "import math\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据大小: (506, 14)\n"
     ]
    }
   ],
   "source": [
    "# housing.data is whitespace-delimited and its first line is the header (CRIM ... MEDV).\n",
    "# A raw string keeps the regex separator free of the invalid \"\\s\" escape (SyntaxWarning on Python 3.12+).\n",
    "boston_houseprice_data = pd.read_csv(\"data/housing.data\", header=0, index_col=None, sep=r'\\s+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据大小:  \n",
      " (506, 14)\n",
      "数据示例： \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CRIM</th>\n",
       "      <th>ZN</th>\n",
       "      <th>INDUS</th>\n",
       "      <th>CHAS</th>\n",
       "      <th>NOX</th>\n",
       "      <th>RM</th>\n",
       "      <th>AGE</th>\n",
       "      <th>DIS</th>\n",
       "      <th>RAD</th>\n",
       "      <th>TAX</th>\n",
       "      <th>PTRATIO</th>\n",
       "      <th>B</th>\n",
       "      <th>LSTAT</th>\n",
       "      <th>MEDV</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.00632</td>\n",
       "      <td>18.0</td>\n",
       "      <td>2.31</td>\n",
       "      <td>0</td>\n",
       "      <td>0.538</td>\n",
       "      <td>6.575</td>\n",
       "      <td>65.2</td>\n",
       "      <td>4.0900</td>\n",
       "      <td>1</td>\n",
       "      <td>296.0</td>\n",
       "      <td>15.3</td>\n",
       "      <td>396.90</td>\n",
       "      <td>4.98</td>\n",
       "      <td>24.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.02731</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>6.421</td>\n",
       "      <td>78.9</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2</td>\n",
       "      <td>242.0</td>\n",
       "      <td>17.8</td>\n",
       "      <td>396.90</td>\n",
       "      <td>9.14</td>\n",
       "      <td>21.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.02729</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>7.185</td>\n",
       "      <td>61.1</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2</td>\n",
       "      <td>242.0</td>\n",
       "      <td>17.8</td>\n",
       "      <td>392.83</td>\n",
       "      <td>4.03</td>\n",
       "      <td>34.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.03237</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.18</td>\n",
       "      <td>0</td>\n",
       "      <td>0.458</td>\n",
       "      <td>6.998</td>\n",
       "      <td>45.8</td>\n",
       "      <td>6.0622</td>\n",
       "      <td>3</td>\n",
       "      <td>222.0</td>\n",
       "      <td>18.7</td>\n",
       "      <td>394.63</td>\n",
       "      <td>2.94</td>\n",
       "      <td>33.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.06905</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.18</td>\n",
       "      <td>0</td>\n",
       "      <td>0.458</td>\n",
       "      <td>7.147</td>\n",
       "      <td>54.2</td>\n",
       "      <td>6.0622</td>\n",
       "      <td>3</td>\n",
       "      <td>222.0</td>\n",
       "      <td>18.7</td>\n",
       "      <td>396.90</td>\n",
       "      <td>5.33</td>\n",
       "      <td>36.2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      CRIM    ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD    TAX  \\\n",
       "0  0.00632  18.0   2.31     0  0.538  6.575  65.2  4.0900    1  296.0   \n",
       "1  0.02731   0.0   7.07     0  0.469  6.421  78.9  4.9671    2  242.0   \n",
       "2  0.02729   0.0   7.07     0  0.469  7.185  61.1  4.9671    2  242.0   \n",
       "3  0.03237   0.0   2.18     0  0.458  6.998  45.8  6.0622    3  222.0   \n",
       "4  0.06905   0.0   2.18     0  0.458  7.147  54.2  6.0622    3  222.0   \n",
       "\n",
       "   PTRATIO       B  LSTAT  MEDV  \n",
       "0     15.3  396.90   4.98  24.0  \n",
       "1     17.8  396.90   9.14  21.6  \n",
       "2     17.8  392.83   4.03  34.7  \n",
       "3     18.7  394.63   2.94  33.4  \n",
       "4     18.7  396.90   5.33  36.2  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check the raw frame: its dimensions first, then the leading rows.\n",
    "print(\"数据大小:  \\n\", boston_houseprice_data.shape)\n",
    "print(\"数据示例： \")\n",
    "boston_houseprice_data.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All columns but the last are features; the last column (MEDV) is the target, kept 2-D as (n, 1).\n",
    "data_sample = boston_houseprice_data.iloc[:, :-1].to_numpy()\n",
    "data_label = boston_houseprice_data.iloc[:, [-1]].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the scaling statistics on the training rows only (the first 80%, matching the\n",
    "# sequential split further down) so no information from the test rows leaks into training.\n",
    "n_train_rows = int(data_sample.shape[0] * 0.8)\n",
    "mean = data_sample[:n_train_rows].mean(axis=0)\n",
    "std = data_sample[:n_train_rows].std(axis=0)\n",
    "# Guard against constant features, which would otherwise cause a division by zero.\n",
    "std = np.where(std == 0, 1.0, std)\n",
    "data_sample = (data_sample - mean) / std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_label_length: 404\n"
     ]
    }
   ],
   "source": [
    "# Size of the training portion: the first 80% of rows (no shuffling happens before the split).\n",
    "data_length = len(data_label)\n",
    "train_data_length = int(0.8 * data_length)\n",
    "print(\"train_label_length:\", train_data_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential split: rows before train_data_length go to training, the remainder to test.\n",
    "data_sample_train, data_sample_test = np.split(data_sample, [train_data_length])\n",
    "data_label_train, data_label_test = np.split(data_label, [train_data_length])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimisation settings\n",
    "num_iterations = 1000\n",
    "lr = 0.001\n",
    "weight_decay = 0.01\n",
    "\n",
    "# Batch sizes\n",
    "train_batch_size = 16\n",
    "test_batch_size = 100\n",
    "\n",
    "# Early-stopping / bookkeeping state; presumably read and updated by the training loop — NOTE(review): confirm\n",
    "max_loss = float(\"inf\")\n",
    "early_stopping_iter = 15\n",
    "early_stopping_mark = 0\n",
    "train_error = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The positional argument is presumably the batch size (the literal 16 equalled train_batch_size,\n",
    "# which DNNNet also receives); reuse the constant so the two cannot drift apart.\n",
    "data_handler = DataHander(train_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = Optimizer(\n",
    "    lr=lr,\n",
    "    momentum=0.9,\n",
    "    iteration=0,\n",
    "    gamma=0.0005,\n",
    "    power=0.75,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Its xavier method is handed to DNNNet as the weight initializer.\n",
    "initializer = Initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training split into the handler, then shuffle it.\n",
    "# NOTE(review): no random seed is set in any visible cell, so batch order may differ between runs.\n",
    "data_handler.get_data(sample=data_sample_train, label=data_label_train)\n",
    "data_handler.shuffle()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "dnn = DNNNet(\n",
    "    optimizer=opt.batch_gradient_descent_anneling,\n",
    "    initializer=initializer.xavier,\n",
    "    batch_size=train_batch_size,\n",
    "    weights_decay=weight_decay,\n",
    ")\n",
    "dnn.initial()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第 0 次迭代\n",
      "predict: \n",
      "  [[-2.17855006]\n",
      " [-0.06830612]\n",
      " [-1.88400996]\n",
      " [ 0.15689916]\n",
      " [-1.5764663 ]\n",
      " [-1.92708457]\n",
      " [ 0.20079785]\n",
      " [ 0.0439734 ]\n",
      " [ 0.48260857]\n",
      " [-0.49813407]\n",
      " [ 0.35426702]\n",
      " [-0.01529609]\n",
      " [ 0.77557443]\n",
      " [ 0.63569874]\n",
      " [ 0.1770889 ]\n",
      " [-0.1475898 ]]\n",
      "label: \n",
      " [[10.5]\n",
      " [22.8]\n",
      " [ 9.7]\n",
      " [20.3]\n",
      " [22.7]\n",
      " [10.2]\n",
      " [20.8]\n",
      " [25.3]\n",
      " [31.2]\n",
      " [27. ]\n",
      " [19.7]\n",
      " [19.3]\n",
      " [21.9]\n",
      " [24.8]\n",
      " [17.1]\n",
      " [36.2]]\n",
      "loss:  506.68602730198205\n",
      "第 1 次迭代\n",
      "predict: \n",
      "  [[-0.56815157]\n",
      " [ 0.28502612]\n",
      " [-0.04192595]\n",
      " [ 0.63710978]\n",
      " [-0.0362799 ]\n",
      " [-0.39731007]\n",
      " [ 0.19094034]\n",
      " [-0.0150731 ]\n",
      " [-0.02125388]\n",
      " [-0.04385016]\n",
      " [-0.05119533]\n",
      " [ 0.67666849]\n",
      " [-0.41764421]\n",
      " [ 0.2100289 ]\n",
      " [-1.12957729]\n",
      " [-0.04962692]]\n",
      "label: \n",
      " [[14. ]\n",
      " [50. ]\n",
      " [19. ]\n",
      " [18.2]\n",
      " [19.9]\n",
      " [20.5]\n",
      " [21.1]\n",
      " [21.9]\n",
      " [21.7]\n",
      " [23.6]\n",
      " [18.4]\n",
      " [22. ]\n",
      " [20.3]\n",
      " [39.8]\n",
      " [23.3]\n",
      " [24. ]]\n",
      "loss:  631.5139215340353\n",
      "第 2 次迭代\n",
      "predict: \n",
      "  [[-1.24578309]\n",
      " [-0.2409983 ]\n",
      " [-0.90125377]\n",
      " [ 0.11160639]\n",
      " [ 0.19871666]\n",
      " [ 0.0111124 ]\n",
      " [ 0.87254843]\n",
      " [ 0.14586281]\n",
      " [ 0.16105556]\n",
      " [ 0.67197252]\n",
      " [ 0.72825802]\n",
      " [-1.21884464]\n",
      " [-0.40271881]\n",
      " [ 0.90204264]\n",
      " [ 0.9191594 ]\n",
      " [ 0.792727  ]]\n",
      "label: \n",
      " [[ 6.3]\n",
      " [17.4]\n",
      " [17. ]\n",
      " [18.5]\n",
      " [48.8]\n",
      " [19.4]\n",
      " [15. ]\n",
      " [15.2]\n",
      " [22.6]\n",
      " [17.2]\n",
      " [17.5]\n",
      " [13.3]\n",
      " [18.1]\n",
      " [32.7]\n",
      " [23.2]\n",
      " [22. ]]\n",
      "loss:  484.9607680513378\n",
      "第 3 次迭代\n",
      "predict: \n",
      "  [[ 1.09947453]\n",
      " [ 0.21586492]\n",
      " [ 1.05191557]\n",
      " [ 0.205219  ]\n",
      " [ 0.18717223]\n",
      " [ 0.98183719]\n",
      " [ 1.69036714]\n",
      " [ 0.38869292]\n",
      " [-1.0767959 ]\n",
      " [ 0.13075649]\n",
      " [ 1.28563806]\n",
      " [ 0.03151431]\n",
      " [ 0.62443023]\n",
      " [ 0.95274995]\n",
      " [-1.11075444]\n",
      " [ 1.14361791]]\n",
      "label: \n",
      " [[28.6]\n",
      " [19.2]\n",
      " [34.7]\n",
      " [19.4]\n",
      " [19.5]\n",
      " [33.4]\n",
      " [46. ]\n",
      " [22.8]\n",
      " [13.1]\n",
      " [29.4]\n",
      " [29.6]\n",
      " [13.9]\n",
      " [28.2]\n",
      " [50. ]\n",
      " [12.1]\n",
      " [27.1]]\n",
      "loss:  785.8033167613589\n",
      "第 4 次迭代\n",
      "predict: \n",
      "  [[ 0.37022767]\n",
      " [ 1.50570241]\n",
      " [ 0.34379985]\n",
      " [ 0.93965599]\n",
      " [ 0.7589821 ]\n",
      " [ 1.06457954]\n",
      " [ 0.84647746]\n",
      " [ 1.27869432]\n",
      " [-1.5896347 ]\n",
      " [ 1.41026414]\n",
      " [ 1.15973592]\n",
      " [ 0.90093415]\n",
      " [ 0.29509539]\n",
      " [ 0.45848011]\n",
      " [-0.8852695 ]\n",
      " [-0.59539079]]\n",
      "label: \n",
      " [[18.3]\n",
      " [41.7]\n",
      " [33.8]\n",
      " [28.7]\n",
      " [22.2]\n",
      " [17.6]\n",
      " [28.4]\n",
      " [24.4]\n",
      " [23.1]\n",
      " [48.3]\n",
      " [19.4]\n",
      " [19.4]\n",
      " [30.1]\n",
      " [21.7]\n",
      " [21.7]\n",
      " [50. ]]\n",
      "loss:  859.1632260677934\n",
      "第 5 次迭代\n",
      "predict: \n",
      "  [[ 1.65496523]\n",
      " [ 0.96039416]\n",
      " [ 0.61355325]\n",
      " [-0.59091719]\n",
      " [ 1.63890535]\n",
      " [ 1.52907012]\n",
      " [ 1.29049762]\n",
      " [ 0.52048607]\n",
      " [ 0.66045688]\n",
      " [ 0.83746837]\n",
      " [ 0.85819842]\n",
      " [ 2.71611989]\n",
      " [ 1.8056577 ]\n",
      " [-0.73164998]\n",
      " [-0.7394702 ]\n",
      " [ 2.12481603]]\n",
      "label: \n",
      " [[23.4]\n",
      " [23.1]\n",
      " [21.2]\n",
      " [16.8]\n",
      " [28. ]\n",
      " [21.2]\n",
      " [21.7]\n",
      " [20.4]\n",
      " [18.8]\n",
      " [24.3]\n",
      " [24. ]\n",
      " [33. ]\n",
      " [30.8]\n",
      " [22.6]\n",
      " [11.3]\n",
      " [34.9]]\n",
      "loss:  532.4524513639537\n",
      "第 6 次迭代\n",
      "predict: \n",
      "  [[ 3.06434478]\n",
      " [ 1.71931355]\n",
      " [ 2.18739952]\n",
      " [ 1.45996977]\n",
      " [ 1.88894523]\n",
      " [ 1.95691614]\n",
      " [ 1.7688486 ]\n",
      " [ 0.86748172]\n",
      " [ 1.3332557 ]\n",
      " [ 1.60724242]\n",
      " [-0.55081543]\n",
      " [ 3.403988  ]\n",
      " [ 3.09044937]\n",
      " [ 2.50681769]\n",
      " [ 1.50704176]\n",
      " [ 0.18546078]]\n",
      "label: \n",
      " [[38.7]\n",
      " [18.5]\n",
      " [20.3]\n",
      " [20.6]\n",
      " [32. ]\n",
      " [22.8]\n",
      " [16. ]\n",
      " [16.1]\n",
      " [16.5]\n",
      " [36. ]\n",
      " [ 7.2]\n",
      " [42.8]\n",
      " [29. ]\n",
      " [50. ]\n",
      " [20.4]\n",
      " [11.8]]\n",
      "loss:  654.0828025622525\n",
      "第 7 次迭代\n",
      "predict: \n",
      "  [[ 1.67079874]\n",
      " [ 4.29189021]\n",
      " [-0.16068646]\n",
      " [ 3.0644344 ]\n",
      " [ 1.53813254]\n",
      " [ 0.60928006]\n",
      " [ 0.67306082]\n",
      " [ 3.43698574]\n",
      " [ 1.46381188]\n",
      " [ 1.46408308]\n",
      " [ 3.33683888]\n",
      " [ 3.82665476]\n",
      " [ 3.62475149]\n",
      " [ 1.76877274]\n",
      " [ 1.73006897]\n",
      " [ 1.29844473]]\n",
      "label: \n",
      " [[18.2]\n",
      " [31.6]\n",
      " [11.5]\n",
      " [34.9]\n",
      " [23.1]\n",
      " [15.4]\n",
      " [13.8]\n",
      " [20.9]\n",
      " [19.8]\n",
      " [22.1]\n",
      " [26.2]\n",
      " [32.9]\n",
      " [36.4]\n",
      " [18.9]\n",
      " [20.4]\n",
      " [14.8]]\n",
      "loss:  459.87310673344336\n",
      "第 8 次迭代\n",
      "predict: \n",
      "  [[ 2.1969742 ]\n",
      " [ 1.59954002]\n",
      " [ 1.37819997]\n",
      " [ 2.93592023]\n",
      " [ 4.48746946]\n",
      " [ 1.4687194 ]\n",
      " [ 5.85761537]\n",
      " [ 2.37879234]\n",
      " [ 0.06205967]\n",
      " [ 5.16739039]\n",
      " [-0.65838315]\n",
      " [ 4.44722072]\n",
      " [ 1.15576116]\n",
      " [ 3.17074037]\n",
      " [ 1.8735698 ]\n",
      " [ 4.17980724]]\n",
      "label: \n",
      " [[19.3]\n",
      " [25. ]\n",
      " [19.6]\n",
      " [24.4]\n",
      " [28.7]\n",
      " [23.8]\n",
      " [33.3]\n",
      " [21. ]\n",
      " [15. ]\n",
      " [33.2]\n",
      " [10.4]\n",
      " [50. ]\n",
      " [15.6]\n",
      " [21.1]\n",
      " [13.6]\n",
      " [31.5]]\n",
      "loss:  527.3694748631779\n",
      "第 9 次迭代\n",
      "predict: \n",
      "  [[6.23664896]\n",
      " [0.50606548]\n",
      " [6.02595742]\n",
      " [3.94369022]\n",
      " [0.47022228]\n",
      " [1.37000503]\n",
      " [0.56306839]\n",
      " [6.29037549]\n",
      " [1.41494955]\n",
      " [6.09040461]\n",
      " [6.53724877]\n",
      " [7.44407006]\n",
      " [4.71741976]\n",
      " [0.46511104]\n",
      " [0.89089262]\n",
      " [6.30951345]]\n",
      "label: \n",
      " [[22.9]\n",
      " [15.1]\n",
      " [20.6]\n",
      " [18.7]\n",
      " [ 8.8]\n",
      " [15.3]\n",
      " [12.3]\n",
      " [32.4]\n",
      " [15.6]\n",
      " [23.1]\n",
      " [26.6]\n",
      " [45.4]\n",
      " [20.9]\n",
      " [12.7]\n",
      " [27.5]\n",
      " [18.6]]\n",
      "loss:  350.50482892657783\n",
      "第 10 次迭代\n",
      "predict: \n",
      "  [[4.67821641]\n",
      " [5.19262711]\n",
      " [2.38821171]\n",
      " [0.76477764]\n",
      " [6.92670823]\n",
      " [4.74344792]\n",
      " [2.5040266 ]\n",
      " [5.03263932]\n",
      " [7.6908715 ]\n",
      " [5.80181434]\n",
      " [1.2029961 ]\n",
      " [1.86864539]\n",
      " [5.16910517]\n",
      " [6.73624442]\n",
      " [8.79662839]\n",
      " [5.51201951]]\n",
      "label: \n",
      " [[21.4]\n",
      " [36.5]\n",
      " [22.7]\n",
      " [13.1]\n",
      " [24.5]\n",
      " [18.7]\n",
      " [18.5]\n",
      " [22.9]\n",
      " [44.8]\n",
      " [50. ]\n",
      " [12.5]\n",
      " [13.4]\n",
      " [18.9]\n",
      " [26.6]\n",
      " [22.5]\n",
      " [20.7]]\n",
      "loss:  468.4962902211165\n",
      "第 11 次迭代\n",
      "predict: \n",
      "  [[ 9.09530612]\n",
      " [ 7.56197964]\n",
      " [ 5.47456075]\n",
      " [ 4.04290155]\n",
      " [ 7.58823759]\n",
      " [ 6.96752157]\n",
      " [10.08011917]\n",
      " [ 9.69703045]\n",
      " [ 6.6863498 ]\n",
      " [ 8.12735097]\n",
      " [ 6.83659152]\n",
      " [ 7.65455872]\n",
      " [ 3.71783081]\n",
      " [ 1.26462557]\n",
      " [ 2.45354835]\n",
      " [ 8.89251972]]\n",
      "label: \n",
      " [[22.2]\n",
      " [31.6]\n",
      " [20.6]\n",
      " [22.5]\n",
      " [19.8]\n",
      " [31.5]\n",
      " [26.6]\n",
      " [24.1]\n",
      " [21.4]\n",
      " [32.5]\n",
      " [25. ]\n",
      " [43.1]\n",
      " [22.8]\n",
      " [13.9]\n",
      " [18.4]\n",
      " [33.4]]\n",
      "loss:  395.3636543300633\n",
      "第 12 次迭代\n",
      "predict: \n",
      "  [[16.7617491 ]\n",
      " [17.57161939]\n",
      " [ 6.84659903]\n",
      " [17.81339092]\n",
      " [ 7.18337986]\n",
      " [ 4.73893588]\n",
      " [ 3.26443493]\n",
      " [ 1.64136906]\n",
      " [11.49724891]\n",
      " [ 8.29670175]\n",
      " [ 9.65016353]\n",
      " [ 2.78410841]\n",
      " [11.30992565]\n",
      " [13.03049009]\n",
      " [ 2.96977445]\n",
      " [ 1.66718309]]\n",
      "label: \n",
      " [[35.4]\n",
      " [50. ]\n",
      " [21.2]\n",
      " [50. ]\n",
      " [22.2]\n",
      " [14.4]\n",
      " [19.5]\n",
      " [21.9]\n",
      " [29.8]\n",
      " [29.6]\n",
      " [24.4]\n",
      " [13.3]\n",
      " [24.7]\n",
      " [28.5]\n",
      " [23. ]\n",
      " [ 8.3]]\n",
      "loss:  350.9118636628382\n",
      "第 13 次迭代\n",
      "predict: \n",
      "  [[13.17332866]\n",
      " [17.88778939]\n",
      " [18.51088642]\n",
      " [ 4.96253577]\n",
      " [12.40956893]\n",
      " [11.80476032]\n",
      " [10.26509741]\n",
      " [ 6.70547979]\n",
      " [ 7.87427247]\n",
      " [13.96770079]\n",
      " [ 2.18320136]\n",
      " [ 2.2376628 ]\n",
      " [ 8.59867602]\n",
      " [ 2.27307024]\n",
      " [ 1.92242838]\n",
      " [ 3.85674046]]\n",
      "label: \n",
      " [[20.5]\n",
      " [43.8]\n",
      " [34.9]\n",
      " [24.3]\n",
      " [24.6]\n",
      " [19.6]\n",
      " [26.7]\n",
      " [27.5]\n",
      " [21.6]\n",
      " [23.9]\n",
      " [50. ]\n",
      " [13.8]\n",
      " [24.4]\n",
      " [13.8]\n",
      " [19.9]\n",
      " [20.4]]\n",
      "loss:  372.8848942837481\n",
      "第 14 次迭代\n",
      "predict: \n",
      "  [[ 3.42634542]\n",
      " [15.7125885 ]\n",
      " [17.30163982]\n",
      " [ 3.88168767]\n",
      " [ 2.36389392]\n",
      " [ 3.76024376]\n",
      " [14.07301244]\n",
      " [22.86293858]\n",
      " [ 4.58573397]\n",
      " [21.34313906]\n",
      " [16.24800987]\n",
      " [16.35135276]\n",
      " [16.85764174]\n",
      " [ 5.85086889]\n",
      " [12.22892037]\n",
      " [22.62143473]]\n",
      "label: \n",
      " [[16.2]\n",
      " [21.6]\n",
      " [24.3]\n",
      " [21.5]\n",
      " [10.2]\n",
      " [17.1]\n",
      " [24.6]\n",
      " [23.5]\n",
      " [19.3]\n",
      " [33.1]\n",
      " [22.6]\n",
      " [25.2]\n",
      " [31.7]\n",
      " [19.8]\n",
      " [16.6]\n",
      " [20.1]]\n",
      "loss:  113.84402361777401\n",
      "第 15 次迭代\n",
      "predict: \n",
      "  [[15.3607671 ]\n",
      " [ 4.52117789]\n",
      " [16.7606664 ]\n",
      " [10.8773677 ]\n",
      " [ 2.62719266]\n",
      " [10.77463281]\n",
      " [16.20257907]\n",
      " [ 3.65865184]\n",
      " [ 4.92224521]\n",
      " [ 6.33905658]\n",
      " [15.56044555]\n",
      " [16.33949867]\n",
      " [ 5.56933383]\n",
      " [ 2.61243538]\n",
      " [ 9.51546657]\n",
      " [ 7.12742809]]\n",
      "label: \n",
      " [[30.7]\n",
      " [19.3]\n",
      " [19.3]\n",
      " [21. ]\n",
      " [23.2]\n",
      " [20. ]\n",
      " [18.9]\n",
      " [19.4]\n",
      " [19.1]\n",
      " [22.3]\n",
      " [30.1]\n",
      " [25. ]\n",
      " [23.8]\n",
      " [13.8]\n",
      " [20.3]\n",
      " [17.5]]\n",
      "loss:  171.85087779773153\n",
      "第 16 次迭代\n",
      "predict: \n",
      "  [[30.3499307 ]\n",
      " [16.13496681]\n",
      " [ 4.64228001]\n",
      " [23.29189869]\n",
      " [30.06823638]\n",
      " [46.13279705]\n",
      " [19.20424104]\n",
      " [19.8300109 ]\n",
      " [ 4.06492789]\n",
      " [13.73961417]\n",
      " [ 7.60812001]\n",
      " [ 5.62281515]\n",
      " [37.13428436]\n",
      " [26.53859748]\n",
      " [40.1098364 ]\n",
      " [22.39659779]]\n",
      "label: \n",
      " [[27.9]\n",
      " [27.1]\n",
      " [13.2]\n",
      " [23.9]\n",
      " [23.7]\n",
      " [30.1]\n",
      " [20. ]\n",
      " [50. ]\n",
      " [14.3]\n",
      " [23. ]\n",
      " [16.6]\n",
      " [12.7]\n",
      " [29.1]\n",
      " [22. ]\n",
      " [32.2]\n",
      " [29.9]]\n",
      "loss:  120.86075055288418\n",
      "第 17 次迭代\n",
      "predict: \n",
      "  [[20.42464214]\n",
      " [ 3.46953736]\n",
      " [ 5.28036492]\n",
      " [ 3.88177991]\n",
      " [48.20549501]\n",
      " [17.68909017]\n",
      " [27.71250332]\n",
      " [20.71954875]\n",
      " [20.60368811]\n",
      " [18.43072675]\n",
      " [32.08278581]\n",
      " [ 9.2747561 ]\n",
      " [38.26382225]\n",
      " [33.85125692]\n",
      " [ 8.23516127]\n",
      " [27.11546238]]\n",
      "label: \n",
      " [[21.7]\n",
      " [ 7.2]\n",
      " [19.2]\n",
      " [ 5. ]\n",
      " [44. ]\n",
      " [22.6]\n",
      " [28.4]\n",
      " [28.1]\n",
      " [50. ]\n",
      " [22.6]\n",
      " [46.7]\n",
      " [21.5]\n",
      " [37.3]\n",
      " [32. ]\n",
      " [15.6]\n",
      " [22.9]]\n",
      "loss:  101.76931821620884\n",
      "第 18 次迭代\n",
      "predict: \n",
      "  [[52.74082809]\n",
      " [ 5.92014697]\n",
      " [15.0874128 ]\n",
      " [ 3.48068593]\n",
      " [35.23510442]\n",
      " [30.59946701]\n",
      " [37.62729867]\n",
      " [ 6.30928931]\n",
      " [27.45783626]\n",
      " [36.84772757]\n",
      " [43.07361193]\n",
      " [ 4.2757968 ]\n",
      " [16.08248249]\n",
      " [29.82253569]\n",
      " [21.85222871]\n",
      " [ 5.71577841]]\n",
      "label: \n",
      " [[30.3]\n",
      " [18.6]\n",
      " [23.1]\n",
      " [25. ]\n",
      " [19.4]\n",
      " [23.9]\n",
      " [23.3]\n",
      " [15.7]\n",
      " [21.7]\n",
      " [37.6]\n",
      " [35.1]\n",
      " [ 7.4]\n",
      " [22.4]\n",
      " [22.2]\n",
      " [19.5]\n",
      " [19.6]]\n",
      "loss:  136.50741167428268\n",
      "第 19 次迭代\n",
      "predict: \n",
      "  [[ 4.16238042]\n",
      " [26.71630346]\n",
      " [50.07503194]\n",
      " [17.83880617]\n",
      " [32.88882086]\n",
      " [17.77785466]\n",
      " [37.65145125]\n",
      " [31.42043619]\n",
      " [43.91401576]\n",
      " [39.3468875 ]\n",
      " [19.05483036]\n",
      " [31.7796472 ]\n",
      " [25.61903317]\n",
      " [ 4.73636612]\n",
      " [60.2798781 ]\n",
      " [23.86302783]]\n",
      "label: \n",
      " [[50. ]\n",
      " [23.1]\n",
      " [35.4]\n",
      " [21. ]\n",
      " [31. ]\n",
      " [18.5]\n",
      " [26.4]\n",
      " [20.7]\n",
      " [23.7]\n",
      " [24.8]\n",
      " [26.5]\n",
      " [22. ]\n",
      " [24.1]\n",
      " [15.6]\n",
      " [34.6]\n",
      " [25. ]]\n",
      "loss:  258.59325260163934\n",
      "第 20 次迭代\n",
      "predict: \n",
      "  [[ 6.40806875]\n",
      " [ 8.14103406]\n",
      " [50.70526953]\n",
      " [53.25992181]\n",
      " [12.5066156 ]\n",
      " [44.17628204]\n",
      " [ 7.59708408]\n",
      " [28.61377385]\n",
      " [12.8344624 ]\n",
      " [22.13084369]\n",
      " [ 4.77263658]\n",
      " [12.65077715]\n",
      " [33.94453503]\n",
      " [ 8.91498223]\n",
      " [41.99670098]\n",
      " [33.66174561]]\n",
      "label: \n",
      " [[14.6]\n",
      " [18.7]\n",
      " [24.5]\n",
      " [24.8]\n",
      " [17.8]\n",
      " [33.2]\n",
      " [13.5]\n",
      " [22.5]\n",
      " [18.2]\n",
      " [41.3]\n",
      " [ 5.6]\n",
      " [19.6]\n",
      " [24.2]\n",
      " [18.8]\n",
      " [37.9]\n",
      " [20.1]]\n",
      "loss:  170.91245942896578\n",
      "第 21 次迭代\n",
      "predict: \n",
      "  [[ 7.51982399]\n",
      " [36.43994646]\n",
      " [ 9.39913181]\n",
      " [17.76902349]\n",
      " [ 6.7840239 ]\n",
      " [25.72707659]\n",
      " [28.9095405 ]\n",
      " [11.04265051]\n",
      " [31.56800264]\n",
      " [42.94384378]\n",
      " [36.62877192]\n",
      " [73.66613182]\n",
      " [ 5.15626531]\n",
      " [40.42905809]\n",
      " [16.97705911]\n",
      " [31.00367762]]\n",
      "label: \n",
      " [[23.7]\n",
      " [23.9]\n",
      " [14.5]\n",
      " [21.7]\n",
      " [17.8]\n",
      " [17.4]\n",
      " [22.9]\n",
      " [20. ]\n",
      " [27.5]\n",
      " [22.3]\n",
      " [24.8]\n",
      " [50. ]\n",
      " [20.8]\n",
      " [23.6]\n",
      " [16.2]\n",
      " [23. ]]\n",
      "loss:  156.430580334729\n",
      "第 22 次迭代\n",
      "predict: \n",
      "  [[42.24535476]\n",
      " [25.80942627]\n",
      " [37.99583117]\n",
      " [36.99005134]\n",
      " [ 4.80219444]\n",
      " [13.11966582]\n",
      " [10.76190683]\n",
      " [ 7.18002189]\n",
      " [ 7.88779329]\n",
      " [36.27016424]\n",
      " [31.24929496]\n",
      " [43.49621474]\n",
      " [32.39156334]\n",
      " [49.44337686]\n",
      " [ 9.84232007]\n",
      " [ 9.6263306 ]]\n",
      "label: \n",
      " [[36.1]\n",
      " [23.8]\n",
      " [20.5]\n",
      " [16.5]\n",
      " [50. ]\n",
      " [20.1]\n",
      " [21.4]\n",
      " [13.1]\n",
      " [17.8]\n",
      " [25. ]\n",
      " [23.4]\n",
      " [37.2]\n",
      " [25.1]\n",
      " [43.5]\n",
      " [22. ]\n",
      " [13.1]]\n",
      "loss:  223.90123317493803\n",
      "第 23 次迭代\n",
      "predict: \n",
      "  [[ 5.83054918]\n",
      " [23.46911496]\n",
      " [28.8694545 ]\n",
      " [55.75236995]\n",
      " [11.15386535]\n",
      " [43.55557983]\n",
      " [16.01291497]\n",
      " [49.72864006]\n",
      " [19.37585186]\n",
      " [44.71739422]\n",
      " [42.92334736]\n",
      " [ 5.5161123 ]\n",
      " [35.20496426]\n",
      " [44.08470009]\n",
      " [46.50959865]\n",
      " [47.1126415 ]]\n",
      "label: \n",
      " [[17.8]\n",
      " [24.7]\n",
      " [23.8]\n",
      " [42.3]\n",
      " [17.3]\n",
      " [25. ]\n",
      " [14.5]\n",
      " [31.1]\n",
      " [18.9]\n",
      " [24.7]\n",
      " [30.5]\n",
      " [10.9]\n",
      " [29.1]\n",
      " [37. ]\n",
      " [36.2]\n",
      " [35.2]]\n",
      "loss:  125.1724276292895\n",
      "第 24 次迭代\n",
      "predict: \n",
      "  [[53.8489388 ]\n",
      " [ 6.17578873]\n",
      " [20.43838392]\n",
      " [ 6.75060408]\n",
      " [30.98104724]\n",
      " [49.50330757]\n",
      " [ 9.07308624]\n",
      " [32.60323789]\n",
      " [10.19601366]\n",
      " [10.165143  ]\n",
      " [40.12456003]\n",
      " [32.75620072]\n",
      " [ 8.95151947]\n",
      " [25.62368353]\n",
      " [27.1867051 ]\n",
      " [ 9.22447066]]\n",
      "label: \n",
      " [[48.5]\n",
      " [ 8.5]\n",
      " [20. ]\n",
      " [50. ]\n",
      " [22.2]\n",
      " [50. ]\n",
      " [20.2]\n",
      " [29. ]\n",
      " [18. ]\n",
      " [15.6]\n",
      " [28.7]\n",
      " [22. ]\n",
      " [14.4]\n",
      " [26.4]\n",
      " [23.2]\n",
      " [17.4]]\n",
      "loss:  160.53366023667493\n",
      "第 25 次迭代\n",
      "predict: \n",
      "  [[26.35179795]\n",
      " [33.72632572]\n",
      " [22.43589263]\n",
      " [36.01933366]\n",
      " [ 7.31333779]\n",
      " [29.3167926 ]\n",
      " [ 6.81857542]\n",
      " [18.60346132]\n",
      " [ 6.65824189]\n",
      " [ 7.08447017]\n",
      " [19.53113208]\n",
      " [32.39918138]\n",
      " [31.30862001]\n",
      " [23.52203908]\n",
      " [27.95400129]\n",
      " [25.93486087]]\n",
      "label: \n",
      " [[23.3]\n",
      " [24.1]\n",
      " [23.3]\n",
      " [33.1]\n",
      " [10.5]\n",
      " [22.8]\n",
      " [ 9.7]\n",
      " [20.3]\n",
      " [22.7]\n",
      " [10.2]\n",
      " [20.8]\n",
      " [25.3]\n",
      " [31.2]\n",
      " [27. ]\n",
      " [19.7]\n",
      " [19.3]]\n",
      "loss:  38.647744930820615\n",
      "第 26 次迭代\n",
      "predict: \n",
      "  [[31.00681593]\n",
      " [28.79320213]\n",
      " [22.16456716]\n",
      " [30.09764017]\n",
      " [11.26982968]\n",
      " [35.81948678]\n",
      " [22.15405451]\n",
      " [20.38402177]\n",
      " [16.92046897]\n",
      " [14.35311507]\n",
      " [20.51724021]\n",
      " [16.19882475]\n",
      " [22.10920757]\n",
      " [29.67721315]\n",
      " [20.96029768]\n",
      " [34.36988202]]\n",
      "label: \n",
      " [[21.9]\n",
      " [24.8]\n",
      " [17.1]\n",
      " [36.2]\n",
      " [14. ]\n",
      " [50. ]\n",
      " [19. ]\n",
      " [18.2]\n",
      " [19.9]\n",
      " [20.5]\n",
      " [21.1]\n",
      " [21.9]\n",
      " [21.7]\n",
      " [23.6]\n",
      " [18.4]\n",
      " [22. ]]\n",
      "loss:  41.3250671995126\n",
      "第 27 次迭代\n",
      "predict: \n",
      "  [[17.30002465]\n",
      " [43.49331696]\n",
      " [20.10642098]\n",
      " [32.76564781]\n",
      " [ 7.219066  ]\n",
      " [14.79891302]\n",
      " [12.58052048]\n",
      " [21.00092919]\n",
      " [45.89201153]\n",
      " [20.67682009]\n",
      " [22.1666368 ]\n",
      " [14.95561521]\n",
      " [29.30980058]\n",
      " [20.95165751]\n",
      " [23.2329031 ]\n",
      " [ 8.2745282 ]]\n",
      "label: \n",
      " [[20.3]\n",
      " [39.8]\n",
      " [23.3]\n",
      " [24. ]\n",
      " [ 6.3]\n",
      " [17.4]\n",
      " [17. ]\n",
      " [18.5]\n",
      " [48.8]\n",
      " [19.4]\n",
      " [15. ]\n",
      " [15.2]\n",
      " [22.6]\n",
      " [17.2]\n",
      " [17.5]\n",
      " [13.3]]\n",
      "loss:  20.112358038199385\n",
      "第 28 次迭代\n",
      "predict: \n",
      "  [[13.50928219]\n",
      " [35.47953759]\n",
      " [23.08814586]\n",
      " [26.11824433]\n",
      " [29.79939255]\n",
      " [16.88532995]\n",
      " [36.26859729]\n",
      " [15.69219565]\n",
      " [13.25890697]\n",
      " [35.22205434]\n",
      " [45.21692501]\n",
      " [21.94513695]\n",
      " [ 8.19443006]\n",
      " [27.68981889]\n",
      " [26.28442459]\n",
      " [11.14577053]]\n",
      "label: \n",
      " [[18.1]\n",
      " [32.7]\n",
      " [23.2]\n",
      " [22. ]\n",
      " [28.6]\n",
      " [19.2]\n",
      " [34.7]\n",
      " [19.4]\n",
      " [19.5]\n",
      " [33.4]\n",
      " [46. ]\n",
      " [22.8]\n",
      " [13.1]\n",
      " [29.4]\n",
      " [29.6]\n",
      " [13.9]]\n",
      "loss:  9.872549687468162\n",
      "第 29 次迭代\n",
      "predict: \n",
      "  [[26.46218169]\n",
      " [43.14840982]\n",
      " [ 8.77049108]\n",
      " [27.9846049 ]\n",
      " [13.27081957]\n",
      " [40.26150736]\n",
      " [32.52372399]\n",
      " [27.26617904]\n",
      " [24.82415806]\n",
      " [13.71667094]\n",
      " [23.14015922]\n",
      " [19.3964506 ]\n",
      " [10.20181911]\n",
      " [38.69215131]\n",
      " [17.56891458]\n",
      " [18.30927306]]\n",
      "label: \n",
      " [[28.2]\n",
      " [50. ]\n",
      " [12.1]\n",
      " [27.1]\n",
      " [18.3]\n",
      " [41.7]\n",
      " [33.8]\n",
      " [28.7]\n",
      " [22.2]\n",
      " [17.6]\n",
      " [28.4]\n",
      " [24.4]\n",
      " [23.1]\n",
      " [48.3]\n",
      " [19.4]\n",
      " [19.4]]\n",
      "loss:  26.922710214221567\n",
      "第 30 次迭代\n",
      "predict: \n",
      "  [[32.2717492 ]\n",
      " [13.76326465]\n",
      " [ 9.92697053]\n",
      " [10.95993888]\n",
      " [19.77170065]\n",
      " [21.47233939]\n",
      " [16.49870799]\n",
      " [11.03386438]\n",
      " [22.52427719]\n",
      " [21.87815227]\n",
      " [15.79373822]\n",
      " [12.06424802]\n",
      " [15.09468061]\n",
      " [18.55553475]\n",
      " [19.02154834]\n",
      " [29.88000931]]\n",
      "label: \n",
      " [[30.1]\n",
      " [21.7]\n",
      " [21.7]\n",
      " [50. ]\n",
      " [23.4]\n",
      " [23.1]\n",
      " [21.2]\n",
      " [16.8]\n",
      " [28. ]\n",
      " [21.2]\n",
      " [21.7]\n",
      " [20.4]\n",
      " [18.8]\n",
      " [24.3]\n",
      " [24. ]\n",
      " [33. ]]\n",
      "loss:  126.10392969191778\n",
      "第 31 次迭代\n",
      "predict: \n",
      "  [[20.37809069]\n",
      " [ 9.14002324]\n",
      " [10.10094867]\n",
      " [21.46473155]\n",
      " [40.67423241]\n",
      " [ 8.83656305]\n",
      " [17.66148375]\n",
      " [17.99472375]\n",
      " [20.98284226]\n",
      " [17.3972139 ]\n",
      " [14.00174496]\n",
      " [ 9.73455708]\n",
      " [11.2995774 ]\n",
      " [34.77361068]\n",
      " [10.27392448]\n",
      " [28.15727751]]\n",
      "label: \n",
      " [[30.8]\n",
      " [22.6]\n",
      " [11.3]\n",
      " [34.9]\n",
      " [38.7]\n",
      " [18.5]\n",
      " [20.3]\n",
      " [20.6]\n",
      " [32. ]\n",
      " [22.8]\n",
      " [16. ]\n",
      " [16.1]\n",
      " [16.5]\n",
      " [36. ]\n",
      " [ 7.2]\n",
      " [42.8]]\n",
      "loss:  64.39041743136755\n",
      "第 32 次迭代\n",
      "predict: \n",
      "  [[22.70076737]\n",
      " [43.99243961]\n",
      " [17.78464846]\n",
      " [19.28900579]\n",
      " [18.6123286 ]\n",
      " [23.38934077]\n",
      " [10.6160505 ]\n",
      " [23.07082616]\n",
      " [13.15081466]\n",
      " [19.09499818]\n",
      " [18.87743253]\n",
      " [15.08052663]\n",
      " [15.50103275]\n",
      " [20.39430472]\n",
      " [16.81896764]\n",
      " [20.97767479]]\n",
      "label: \n",
      " [[29. ]\n",
      " [50. ]\n",
      " [20.4]\n",
      " [11.8]\n",
      " [18.2]\n",
      " [31.6]\n",
      " [11.5]\n",
      " [34.9]\n",
      " [23.1]\n",
      " [15.4]\n",
      " [13.8]\n",
      " [20.9]\n",
      " [19.8]\n",
      " [22.1]\n",
      " [26.2]\n",
      " [32.9]]\n",
      "loss:  48.17587870744783\n",
      "第 33 次迭代\n",
      "predict: \n",
      "  [[23.13964599]\n",
      " [19.90173048]\n",
      " [16.1782039 ]\n",
      " [14.53639136]\n",
      " [17.67838782]\n",
      " [21.01262866]\n",
      " [18.21164171]\n",
      " [19.93231306]\n",
      " [25.69908484]\n",
      " [17.22800187]\n",
      " [26.86527006]\n",
      " [16.2927285 ]\n",
      " [11.17511083]\n",
      " [30.59735082]\n",
      " [13.45848273]\n",
      " [38.24395361]]\n",
      "label: \n",
      " [[36.4]\n",
      " [18.9]\n",
      " [20.4]\n",
      " [14.8]\n",
      " [19.3]\n",
      " [25. ]\n",
      " [19.6]\n",
      " [24.4]\n",
      " [28.7]\n",
      " [23.8]\n",
      " [33.3]\n",
      " [21. ]\n",
      " [15. ]\n",
      " [33.2]\n",
      " [10.4]\n",
      " [50. ]]\n",
      "loss:  32.49208344220558\n",
      "第 34 次迭代\n",
      "predict: \n",
      "  [[17.92203878]\n",
      " [16.4894593 ]\n",
      " [14.65992683]\n",
      " [30.15123098]\n",
      " [16.40907024]\n",
      " [11.41440268]\n",
      " [10.37393223]\n",
      " [10.38724472]\n",
      " [12.71175482]\n",
      " [19.39430915]\n",
      " [12.37985637]\n",
      " [28.76969617]\n",
      " [21.06807377]\n",
      " [15.46015733]\n",
      " [18.47786182]\n",
      " [36.07891794]]\n",
      "label: \n",
      " [[15.6]\n",
      " [21.1]\n",
      " [13.6]\n",
      " [31.5]\n",
      " [22.9]\n",
      " [15.1]\n",
      " [20.6]\n",
      " [18.7]\n",
      " [ 8.8]\n",
      " [15.3]\n",
      " [12.3]\n",
      " [32.4]\n",
      " [15.6]\n",
      " [23.1]\n",
      " [26.6]\n",
      " [45.4]]\n",
      "loss:  34.08444587008295\n",
      "第 35 次迭代\n",
      "predict: \n",
      "  [[12.94986388]\n",
      " [11.04876341]\n",
      " [13.76130147]\n",
      " [11.17524545]\n",
      " [16.47451385]\n",
      " [33.31209826]\n",
      " [22.4060321 ]\n",
      " [10.85724459]\n",
      " [13.91414312]\n",
      " [18.86717903]\n",
      " [20.56942973]\n",
      " [18.89183759]\n",
      " [38.84338354]\n",
      " [36.38948617]\n",
      " [12.35820922]\n",
      " [19.36869188]]\n",
      "label: \n",
      " [[20.9]\n",
      " [12.7]\n",
      " [27.5]\n",
      " [18.6]\n",
      " [21.4]\n",
      " [36.5]\n",
      " [22.7]\n",
      " [13.1]\n",
      " [24.5]\n",
      " [18.7]\n",
      " [18.5]\n",
      " [22.9]\n",
      " [44.8]\n",
      " [50. ]\n",
      " [12.5]\n",
      " [13.4]]\n",
      "loss:  46.13490863192529\n",
      "第 36 次迭代\n",
      "predict: \n",
      "  [[16.79168297]\n",
      " [26.4134518 ]\n",
      " [17.94066121]\n",
      " [24.11462676]\n",
      " [15.76839973]\n",
      " [30.15096376]\n",
      " [19.28877515]\n",
      " [16.09222952]\n",
      " [15.46276243]\n",
      " [20.28571375]\n",
      " [22.76796563]\n",
      " [16.79993856]\n",
      " [23.81447038]\n",
      " [34.31937552]\n",
      " [19.73320866]\n",
      " [35.77247304]]\n",
      "label: \n",
      " [[18.9]\n",
      " [26.6]\n",
      " [22.5]\n",
      " [20.7]\n",
      " [22.2]\n",
      " [31.6]\n",
      " [20.6]\n",
      " [22.5]\n",
      " [19.8]\n",
      " [31.5]\n",
      " [26.6]\n",
      " [24.1]\n",
      " [21.4]\n",
      " [32.5]\n",
      " [25. ]\n",
      " [43.1]]\n",
      "loss:  26.64311472475081\n",
      "第 37 次迭代\n",
      "predict: \n",
      "  [[19.48575572]\n",
      " [11.8559112 ]\n",
      " [18.02700987]\n",
      " [28.44033917]\n",
      " [22.48889301]\n",
      " [29.65981749]\n",
      " [15.53534177]\n",
      " [29.89873947]\n",
      " [14.83170806]\n",
      " [13.52501364]\n",
      " [16.20559377]\n",
      " [13.03729506]\n",
      " [18.85687816]\n",
      " [24.98544111]\n",
      " [23.14519199]\n",
      " [19.60120544]]\n",
      "label: \n",
      " [[22.8]\n",
      " [13.9]\n",
      " [18.4]\n",
      " [33.4]\n",
      " [35.4]\n",
      " [50. ]\n",
      " [21.2]\n",
      " [50. ]\n",
      " [22.2]\n",
      " [14.4]\n",
      " [19.5]\n",
      " [21.9]\n",
      " [29.8]\n",
      " [29.6]\n",
      " [24.4]\n",
      " [13.3]]\n",
      "loss:  86.4532742525436\n",
      "第 38 次迭代\n",
      "predict: \n",
      "  [[19.05034447]\n",
      " [20.29781449]\n",
      " [21.8677701 ]\n",
      " [11.9723402 ]\n",
      " [15.33760072]\n",
      " [35.47905295]\n",
      " [22.40429207]\n",
      " [28.94071637]\n",
      " [17.85192227]\n",
      " [11.89915986]\n",
      " [32.96846265]\n",
      " [24.05532528]\n",
      " [25.37221849]\n",
      " [21.94803075]\n",
      " [14.95896775]\n",
      " [14.99473497]]\n",
      "label: \n",
      " [[24.7]\n",
      " [28.5]\n",
      " [23. ]\n",
      " [ 8.3]\n",
      " [20.5]\n",
      " [43.8]\n",
      " [34.9]\n",
      " [24.3]\n",
      " [24.6]\n",
      " [19.6]\n",
      " [26.7]\n",
      " [27.5]\n",
      " [21.6]\n",
      " [23.9]\n",
      " [50. ]\n",
      " [13.8]]\n",
      "loss:  111.92937638559735\n",
      "第 39 次迭代\n",
      "predict: \n",
      "  [[24.35408571]\n",
      " [15.81749402]\n",
      " [14.22040151]\n",
      " [19.74189984]\n",
      " [19.75367207]\n",
      " [28.21877013]\n",
      " [12.60729342]\n",
      " [26.11822653]\n",
      " [12.5784183 ]\n",
      " [22.44680626]\n",
      " [27.68830035]\n",
      " [18.28314054]\n",
      " [19.80497141]\n",
      " [29.95883059]\n",
      " [16.96471162]\n",
      " [17.4090444 ]]\n",
      "label: \n",
      " [[24.4]\n",
      " [13.8]\n",
      " [19.9]\n",
      " [20.4]\n",
      " [16.2]\n",
      " [21.6]\n",
      " [24.3]\n",
      " [21.5]\n",
      " [10.2]\n",
      " [17.1]\n",
      " [24.6]\n",
      " [23.5]\n",
      " [19.3]\n",
      " [33.1]\n",
      " [22.6]\n",
      " [25.2]]\n",
      "loss:  26.551454374443107\n",
      "第 40 次迭代\n",
      "predict: \n",
      "  [[33.82596697]\n",
      " [24.80640787]\n",
      " [21.72840054]\n",
      " [14.66452755]\n",
      " [30.36043168]\n",
      " [14.72178071]\n",
      " [14.00199636]\n",
      " [24.37536569]\n",
      " [12.45427848]\n",
      " [19.93862606]\n",
      " [19.49749642]\n",
      " [23.42629495]\n",
      " [24.4642627 ]\n",
      " [27.49663842]\n",
      " [28.25017455]\n",
      " [17.49857017]]\n",
      "label: \n",
      " [[31.7]\n",
      " [19.8]\n",
      " [16.6]\n",
      " [20.1]\n",
      " [30.7]\n",
      " [19.3]\n",
      " [19.3]\n",
      " [21. ]\n",
      " [23.2]\n",
      " [20. ]\n",
      " [18.9]\n",
      " [19.4]\n",
      " [19.1]\n",
      " [22.3]\n",
      " [30.1]\n",
      " [25. ]]\n",
      "loss:  24.592628909822846\n",
      "第 41 次迭代\n",
      "predict: \n",
      "  [[27.16629382]\n",
      " [14.16465007]\n",
      " [22.9323615 ]\n",
      " [18.10925959]\n",
      " [19.92079382]\n",
      " [19.96779303]\n",
      " [12.73403659]\n",
      " [24.16768425]\n",
      " [19.66194943]\n",
      " [22.96624702]\n",
      " [16.62785994]\n",
      " [48.68121204]\n",
      " [19.83458606]\n",
      " [32.91487124]\n",
      " [17.46794466]\n",
      " [15.03474163]]\n",
      "label: \n",
      " [[23.8]\n",
      " [13.8]\n",
      " [20.3]\n",
      " [17.5]\n",
      " [27.9]\n",
      " [27.1]\n",
      " [13.2]\n",
      " [23.9]\n",
      " [23.7]\n",
      " [30.1]\n",
      " [20. ]\n",
      " [50. ]\n",
      " [14.3]\n",
      " [23. ]\n",
      " [16.6]\n",
      " [12.7]]\n",
      "loss:  21.814861713585117\n",
      "第 42 次迭代\n",
      "predict: \n",
      "  [[24.11370568]\n",
      " [19.63937925]\n",
      " [24.79508965]\n",
      " [33.12825698]\n",
      " [13.8511028 ]\n",
      " [14.32490921]\n",
      " [23.90010189]\n",
      " [13.74206056]\n",
      " [29.20376839]\n",
      " [15.75117624]\n",
      " [29.14749879]\n",
      " [20.97813681]\n",
      " [46.76219204]\n",
      " [23.96657963]\n",
      " [30.70300768]\n",
      " [27.36267857]]\n",
      "label: \n",
      " [[29.1]\n",
      " [22. ]\n",
      " [32.2]\n",
      " [29.9]\n",
      " [21.7]\n",
      " [ 7.2]\n",
      " [19.2]\n",
      " [ 5. ]\n",
      " [44. ]\n",
      " [22.6]\n",
      " [28.4]\n",
      " [28.1]\n",
      " [50. ]\n",
      " [22.6]\n",
      " [46.7]\n",
      " [21.5]]\n",
      "loss:  57.894646934939644\n",
      "第 43 次迭代\n",
      "predict: \n",
      "  [[24.73046413]\n",
      " [25.92216297]\n",
      " [18.96258233]\n",
      " [18.04959422]\n",
      " [26.58375249]\n",
      " [17.45066929]\n",
      " [22.6363405 ]\n",
      " [15.54208916]\n",
      " [16.72452995]\n",
      " [21.31568459]\n",
      " [21.63693505]\n",
      " [23.45542675]\n",
      " [21.60185465]\n",
      " [43.52228974]\n",
      " [29.71946801]\n",
      " [14.41244495]]\n",
      "label: \n",
      " [[37.3]\n",
      " [32. ]\n",
      " [15.6]\n",
      " [22.9]\n",
      " [30.3]\n",
      " [18.6]\n",
      " [23.1]\n",
      " [25. ]\n",
      " [19.4]\n",
      " [23.9]\n",
      " [23.3]\n",
      " [15.7]\n",
      " [21.7]\n",
      " [37.6]\n",
      " [35.1]\n",
      " [ 7.4]]\n",
      "loss:  32.78263484272694\n",
      "第 44 次迭代\n",
      "predict: \n",
      "  [[23.74508463]\n",
      " [22.05296761]\n",
      " [18.33898006]\n",
      " [25.09254936]\n",
      " [23.23164869]\n",
      " [23.83969218]\n",
      " [31.60753156]\n",
      " [20.13467891]\n",
      " [37.74861394]\n",
      " [18.75525241]\n",
      " [19.68644015]\n",
      " [21.46494148]\n",
      " [20.71089332]\n",
      " [18.72134679]\n",
      " [26.30090051]\n",
      " [14.69607292]]\n",
      "label: \n",
      " [[22.4]\n",
      " [22.2]\n",
      " [19.5]\n",
      " [19.6]\n",
      " [50. ]\n",
      " [23.1]\n",
      " [35.4]\n",
      " [21. ]\n",
      " [31. ]\n",
      " [18.5]\n",
      " [26.4]\n",
      " [20.7]\n",
      " [23.7]\n",
      " [24.8]\n",
      " [26.5]\n",
      " [22. ]]\n",
      "loss:  59.75675319279509\n"
     ]
    }
   ],
   "source": [
    "# Training loop with patience-based early stopping.\n",
    "# NOTE: despite its name, `max_loss` holds the *lowest* loss seen so far;\n",
    "# `early_stopping_mark` counts iterations since that loss last improved.\n",
    "# NOTE(review): the stopping signal is the loss of the batch pulled this\n",
    "# iteration, and this loop does not snapshot/restore the parameters from\n",
    "# the best iteration -- the model keeps the weights of the final update.\n",
    "for i in range(num_iterations):\n",
    "    print('第', i, '次迭代')\n",
    "    opt.update_iteration(i)\n",
    "    data_handler.pull_data()\n",
    "    dnn.forward_train(data_handler.output_sample, data_handler.output_label)\n",
    "    dnn.backward_train()\n",
    "    dnn.update()\n",
    "    current_loss = dnn.loss.loss\n",
    "    train_error.append(current_loss)\n",
    "    if current_loss < max_loss:\n",
    "        max_loss = current_loss\n",
    "        early_stopping_mark = 0\n",
    "    # Stop after more than `early_stopping_iter` consecutive non-improving iterations.\n",
    "    if early_stopping_mark > early_stopping_iter:\n",
    "        break\n",
    "    early_stopping_mark += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl41NXZ8PHvyTLZJgtJJiELkAQCBFllVdzBVq37pd1otZY+1tY+tbWb7fO+b5fn6Xa1VbtY+1C12optrdqKli7KoqISCCiLRAgkBEJCtskyScgkkznvH/ObMJBJMpksM7/J/bkursz85szk5KfcOdznnPsorTVCCCEiV1SoOyCEEGJ8SaAXQogIJ4FeCCEinAR6IYSIcBLohRAiwkmgF0KICCeBXgghIpwEeiGEiHAS6IUQIsLFhLoDAJmZmbqgoCDU3RBCCFPZs2dPk9baNly7sAj0BQUFlJWVhbobQghhKkqp6kDaSepGCCEinAR6IYSIcBLohRAiwkmgF0KICCeBXgghIpwEeiGEiHAS6IUQIsJJoB+B0spmDtS0hbobQggxIhLoA9Td28dnn97Dj/5ZHuquCCHEiEigD9A/DtbR2tVLQ7sz1F0RQogRkUAfoI07TwDQ1CGBXghhLhLoA3D4tIOy6hYyrRZaunrp7XOHuktCCBEwCfQBeKa0Gkt0FHdeVABAc0dPaDskhBAjIIF+GF09Ll7Ye4rrFkylODsZkPSNEMJcwqJMcTh7aV8tDqeLj6+cQbTxa7FRAr0QwkRkRD+MZ0pPUJxlZXnBFGzWeACaHBLohRDmIYF+CAdq2thX08a6ldNRSpGZbAFkRC+EMBcJ9EN4Zlc18bFR3HJhPgCJlhgSLdE0OWQyVghhHgEFeqXUl5VS7ymlDiql/qiUildKFSqlSpVSFUqpPyulLEbbOOP5UeP1gvH8AcaLo7uXF9+t5YaFuaQmxPZftyXHyWSsEMJUhg30Sqk84IvAMq31fCAa+CjwY+AhrXUx0AKsN96yHmjRWs8CHjLamc7f3q2lq6ePdatmnHM90yqBXghhLoGmbmKABKVUDJAI1AFXAc8Zrz8F3Gw8vsl4jvH6GqWUGpvuTgytNRt3VnNBbgqL8lPPeS3TaqFRJmOFECYybKDXWp8CfgqcwBPg24A9QKvW2mU0qwHyjMd5wEnjvS6jfcbYdnt8vXOylfdPO/i4MQnrS0b0QgizCSR1MwXPKL0QyAWSgGv9NNXetwzxmu/n3q2UKlNKlTU2Ngbe4wmwcecJkizR3LQ4b8BrtuQ4KYMghDCVQFI3a4EqrXWj1roXeAG4GEgzUjkA+UCt8bgGmAZgvJ4K2M//UK31Bq31Mq31MpvNNsofY+y0dfXy8v5abl6ShzVu4H6yTGscIGUQhBDmEUigPwGsUkolGrn2NcAhYBtwm9HmTuBF4/Em4znG61u11gNG9OHq+b01OF1u1q2c4fd1b6CX9I0QwiwCydGX4plU3QscMN6zAfgGcL9S6iieHPzjxlseBzKM6/cDD4xDv8fNtsMNzJ2azLzcFL+v25I9gV42TQkhzCKgWjda628D3z7vciWwwk/bbuD20XctNOrbuynISBr0dZt3RC8rb4QQJiE7Y8/T4HCSnRI/6OtSBkEIYTYS6H04XX20dvWSZaRn/JEyCEIIs5FA78O7ESorZfBAD1IGQQhhLhLofdS3ewP94Kkb8Ky8kd2xQgizkEDvo9HRDTBk6gY8ZRBkRC+EMAsJ9D4avKmb5KFH9JK6EUKYiQR6H/Xt3URHKTKSLEO2y7RKGQQhhHlIoPfR0O7EZo0jKmroYptSBkEIYSYS6H00OJzDrrgBKYMghDAXCfQ+6tu7h52IBSmDIIQwFwn0PhodzmGXVsLZMgiyxFIIYQYS6A29fW6aO3sCGtF7yyBI6kYIYQYS6A3eoD3c0krwlEFIkjIIQgiTkEBv6N8V
G8CIHiBT1tILIUxCAr2hod2zK3aoypW+pAyCEMIsJNAbGgIsaOYlZRCEEGYhgd7Q0N6NUgy7K9ZLyiAIIcxCAr2hweEk0xpHTHRgt0TKIAghzEICvaHB4Qx4IhakDIIQwjwk0BsC3RXr5d0dK+kbIUS4k0Bv8IzoA1txA2dH9FIGQQgR7iTQA31uTXOHk+wAV9yAlEEQQpiHBHqgucOJW4MtwDX0IGUQhBDmIYGeke+KBSmDIIQwDwn0QINjZLtivaQMghDCDCTQ43tWbOAjepAyCEIIc5BAj2dpJZxdSRMom1VG9EKI8CeBHs+IPiPJgiVmZLcjM1nq3Qghwp8EeoxDwUeYtoHAyyD898uH2FHRFGz3hBBiVCTQ45mMDeQIwfMFUgbhdFs3j++o4vm9NUH3TwghRkMCPZ4R/UgnYiGwMgilVc0AVDZ2BNc5IYQYpUkf6N1uTdMId8V6BVIGobTKDkBlYyda6+A6KYQQozDpA729qweXW4+ozo1XIGUQdhmB3uF0SV0cIURIRGygP9HcxT8P1g3bzru0MpjUzXBlEJo6nBxt6GD1rAwAjjV0jvh7CCHEaEVsoH94yxE+v3EvXT2uIdudPUJw5CP64cogeEfzH1sxHYDKJsnTCyEmXsQG+t3H7bg1HKptH7JdYxB1bnxlJscNmpLZVWUnITaaq+dlEx8bRWWjjOiFEBMvIgP96bZuTtrPALC/pm3Itt7UTTDr6MEzIds0SI5+Z2UzS2dMIS4mmsJMq6y8EUKEREQG+l3HPSmT6CjFgVNDB/oGh5PUhFjiY6OD+l6DlUFo7erhcL2DFYXpABTZkqhskhG9EGLiBRTolVJpSqnnlFLvK6XKlVIXKaXSlVKvKKUqjK9TjLZKKfULpdRRpdR+pdSF4/sjDLS7yk6SJZrLijPZX9M6ZNsGR3dQSyu9BiuDsPt4C1rDSiPQz8xM4qS9C6erL+jvJYQQwQh0RP9z4J9a67nAIqAceADYorUuBrYYzwGuBYqNP3cDj45pjwOw+7idC2dMYcn0KVQ2deLo7h20bX37yI4QPN9gZRB2VTVjiYli0bQ0AIpsVtwaqpu7gv5eQggRjGEDvVIqBbgMeBxAa92jtW4FbgKeMpo9BdxsPL4J+L322AmkKaVyxrzng2jr6uVwvYPlBeksyE9Fa3hviAnZRkdwu2K9vLn988sglFbZWTwtrT8lNNNmBWSHrBBi4gUyoi8CGoHfKaXeUUo9ppRKArK11nUAxtcso30ecNLn/TXGtQlRVm1HazyBPi8VgAODTMhqrT2BPoillV7e3bG+6RtHdy8HT7WxykjbABTakgA4JitvhBATLJBAHwNcCDyqtV4CdHI2TeOP8nNtwN5/pdTdSqkypVRZY2NjQJ0NxK7jdmKjFUump5FpjSMvLYH9g0zItnb10tPnHtWIPtPP7tg91S24NawozOi/Zo2LITslTpZYCiEmXCCBvgao0VqXGs+fwxP4670pGeNrg0/7aT7vzwdqz/9QrfUGrfUyrfUym80WbP8H2F1lZ0Fean/KZEFeKgcGmZCtN44QzBrFZKzNT72b0io7MVGKC2ekndO2KNMqm6aEEBNu2ECvtT4NnFRKzTEurQEOAZuAO41rdwIvGo83AXcYq29WAW3eFM946+7t48CpNpb7pEwW5KdyvLmLtjMDJ2Qb+jdLjSJ146cMwq4qOwvyU0m0xJzTtsiWJMXNhBATLmb4JgD8J7BRKWUBKoG78PySeFYptR44AdxutN0MXAccBbqMthPi3ZOt9PZpVhScDfQL8z15+vdOtXHxrMxz2nvLH4xmeeX5ZRDO9PSxv6aV9ZcUDWhbZLPSdqaX5s6eER9bKIQQwQoo0Gut3wWW+XlpjZ+2Grh3lP0Kyu4qO0rBshk+I3pjQna/n0B/tqBZ8CN6OLcMwjsnWujt0/3r533NNCZkKxs7JdALISZMRO2M3XXczpzsZFITY/uvpSVamJ6e6HflTaPDSXJcDAmW
4HbFetl8yiDsrLITpWBpwZQB7WSJpRAiFCIm0Lv63OytbmF5wcCR9IL8VPafGjgh6zlCcPQj60yfMgillc3My00hJT52QLvctAQsMVFSCkEIMaEiJtCX1zno7Ok7ZyLWa0FeKiftZ2jpPHdT02h3xXp5yyA4XX28c7KVlT7LKn1FRykKM5JkRC+EmFARE+i9hcyW+0mZLPRunDpvPf1Yjeht1nhaunrZc7yFHpfbb37eq8iWJJumhBATKmIC/e4qO/lTEshJTRjw2gV+Ar3WOuhDwc/nXWK52TjRyl/6yKvIlsQJexc9LvegbYQQYixFRKDXWrP7uP2cZZW+UhNiKcxMOqeSZXu3C6fLTfYoyh94eVfQ/PPgaeZOTWZKkmXQtjNtVvrcmhN2KW4mhJgYERHoK5s6ae7s8Zuf9/LskD07om8Y5YEjvs7Wu+nprz8/mCJZeSOEmGAREeh3V3nz84MH2YX5qdS2dfevjuk/K3YMJmN90z+DTcR6FXnX0svKGyHEBImIQL/ruJ2MJEv/hiR/FpyXp28w6tyMZlesl+/mp+WFAyeDfaXEx5JpjZMRvRBiwkREoN993M6ygiko5a9wpscFeakodbZkcb23zs0Y5OgTLNEkWaIpsiUF9C8EWXkjhJhIpg/03oPAh0rbgKdM8Eybtf+w8IZ2J4mWaKxxgZb7GdrSgnSuXxDY+SozbYGtpT/a4MDVJ6tzhBCjY/pA710/P9wkKHjW0x8wdsg2OLrHZGml1+8/vYL7PzBn+IZ4Vt60dPUO2MDla39NK2sffJ0X3x1Q4VkIIUbE9IHeexD4vJyUYdsuyE+lvt1JfXu3Zw39GKRtgnF2QnbwUf3jO6qAgZu8hBBipMwf6I2DwGOih/9RvCWLD9S0jfmIfiSKMj1LLAfL059u6+bv+z2bryoaHBPWLyFEZDJ1oG87c/Yg8EDMy0klSnlKFjc4xqbOTTDypyQQG60GPVbw928fx609pY6P1MvqHCHE6Jg60O/xOQg8EAmWaGZnJ7PzWDNdPX1jsrQyGDHRUczISOKYnwnZMz19PLPrBB+YN5U1JVk0Opy0dg2eyxdCiOGYOtCX1zn6DwIP1Py8VMqqPRO4Y1HQLFhFmf5X3jy/t4bWrl4+fUkhxdnJADKqF0KMiqkD/b1XzqL0W2v7DwIPxML8VNzGka2hSt0AzMyycsLedc7ySbdb88SbVSzIS2V5wRRm9wd6ydMLIYJn6kAPkD5EATF/vDtkgZBNxoJnRN/bpznZcqb/2msVjVQ2drL+kkKUUuSmxmONi6FCAr0QYhRMH+hHqiQnhZgozw7aUC2vBP/FzZ7YUUV2ShzXGRuvlFLMyrJK6kYIMSqTLtDHx3omZONiokiJH5tdscHwPSgc4PBpB29UNHHHRQVYYs7+Z5mdbZUllkKIUZl0gR7g8jk2LshNGbI2znhLS7SQnmTpX3nzuzeriI+N4uMrpp/TbnZ2Mk0dPdiH2EUrhBBDCd2QNoS+/sHAShWMN8/Km06aO5y88M4pbluaP+DQkmKfCdlVRUOXQBZCCH8m5YheKRXS0bzXTJuVyqYONpaeoMfl5tOrCwa0mZ3tyeXLhKwQIliTMtCHiyJbEk0dPfzuzSoun21jVlbygDZTU+JJjosJeEK20+nikW1H6e7tG+vuCiFMSgJ9CHlX3rR09bL+kkK/bZRSFGdbA15Lv2lfLT/512Ge21MzZv0UQpibBPoQ8laxLM6ycmlx5qDtZmcnU9EQ2Ih+Z2UzABtLT6C1Hn0nhRCmJ4E+hGakJ7JsxhS+8oE5Q84ZFGcnY+/s6T/vdjBaa94+1ow1LobyunbeOdk61l0WQpiQBPoQiomO4rnPXcw186cO2c47ITtc+qaqqZMGh5P71hSTZIlm484TY9ZXIYR5SaA3AW/Nm4phJmTfNtI2a0qyuHlJHi/vr5XKl0IICfRmkJUcR0p8zLAj+p2VdrJT4ijMTGLdyhk4XW6e33tqgnophAhXEuhN
QCnlmZAdYkTvzc+vKspAKcW83BQunJ7GxtJqmZQVYpKTQG8SxdnJHGlwDBq0jzV20tTh5CKf3bPrVs6gsrGTnZX2ieqmECIMSaA3idnZVlq7emkcZOWNNz/vWybhQwtzSE2I5enS6gnpoxAiPEmgN4nhJmR3VjYzNSWeGRmJ/dfiY6O5bWk+/zp4mkbH0EszhRCRSwK9SRQPscRSa01pZTMXzcwYsB7/4yun43Jrni07OSH9FEKEHwn0JmGzxpGWGOu35s3Rhg6aOnpYVTTwkPSZNisXFWXwx10n6HPLpKwQk1HAgV4pFa2Uekcp9bLxvFApVaqUqlBK/VkpZTGuxxnPjxqvF4xP1ycXpRSzs5L9VrH0lj24qMh/GYV1q6ZT03KG1ysax7WPQojwNJIR/X1Auc/zHwMPaa2LgRZgvXF9PdCitZ4FPGS0E2PAW9zs/JU3b1c2k5saz7T0BL/v+8C8qWRaLbJTVohJKqBAr5TKBz4EPGY8V8BVwHNGk6eAm43HNxnPMV5fo8Kh+HsEmJ2dTHu3iwafiVWtNTsr7azyk5/3ssRE8eFl09j6fj21rWf8thFCRK5AR/QPA18H3MbzDKBVa+0yntcAecbjPOAkgPF6m9FejJK/Cdkj9R3YO3uGPX3qYyumo4E/7ZZJWSEmm2EDvVLqeqBBa73H97KfpjqA13w/926lVJlSqqyxUXLHgZjdf6zg2QnZs/n5oQP9tPRELp9t40+7PKdZCSEmj0BG9KuBG5VSx4E/4UnZPAykKaW8Z87mA7XG4xpgGoDxeiowYGum1nqD1nqZ1nqZzWYb1Q8xWWRa40hPspwzIfv2sWby0hKYlp44xDs97lpdSIPDyaPbj41nN4UQYWbYQK+1/qbWOl9rXQB8FNiqtV4HbANuM5rdCbxoPN5kPMd4fauWYitjpjjr7GlTbremtMqzfj4Ql8+2cdPiXH65tYJDte3j2U0hRBgZzTr6bwD3K6WO4snBP25cfxzIMK7fDzwwui4KX97iZlprDtc7aOnqHTY/7+s7N1xAWqKFr/5lH719ksIRYjIYUaDXWm/XWl9vPK7UWq/QWs/SWt+utXYa17uN57OM1yvHo+OT1exsKw6ni9Pt3f35eX8bpQYzJcnCD26Zz6G6dh7ZdnS8uimECCOyM9Zkin0mZN8+1sy09ATypwyfn/f1gQumcvPiXH619Sjv1bYN276tq5fKxsDOrBVChB8J9CbjXXlz+HQ7pVX2YVfbDOY7N17AlCQLX3l235CrcN461sTVD73GTb96E7eUUBDClCTQm0x6koVMq4VN+2ppOzOy/LyvtEQLP7hlAe+fdvArPymcPrfmwVeOsO6xUlq7enE4XcMeTi6ECE8S6E2oOCuZg6c8q2aCDfQAV8/L5pYlefx621EOnjqbwjnd1s3Hf7uTX2yp4NYl+fzk9oUA1LZ1j67jQoiQkEBvQrONHbIzMhLJTfNf3yZQ375hHulJnlU4PS43295v4LpfvMGBU2387PZF/OzDi5iV5fl+Uj5BCHOKGb6JCDfeCdlVhaOvLJGWaOGHty5g/VNl3P6/b7PvZCtzpybzyLoLmWnzBPjcVM8vEwn0QpiTjOhNqCQnBYCLZ41NCaE1JdncemEe+0628slVM/jbvav7gzxAWmIsCbHR1EnqRghTkhG9CV04PY0/rF/BxTP9158Pxo9uXchnL5vJnKnJA15TSpGTFi8jeiFMSgK9CSmluLR4bOsDWWKi/AZ5r7y0BJmMFcKkJHUjApKTGk+djOiFMCUJ9CIgOakJNHY4pcSxECYkgV4EJC8tAa2hvl3SN0KYjQR6EZCctHgATkn6RgjTkUAvApJjrKWva5NAL4TZSKAXAck1RvS1rZK6EcJsJNCLgCRaYkhLjJW19EKYkAR6EbDc1ATZHSuECUmgFwHLld2xQpiSBHoRsJzUBAn0QpiQBHoRsNy0BNq7XXQ6XaHuihBiBCTQi4B5V97IEkshzEUCvQiYdy39
KVliKYSpSKAXAesf0UueXghTkUAvApadEo9SctKUEGYjgV4ELDY6iuzkeKlLL4TJSKAXI5KTFi+TsUKYjAR6MSK5qQlS70YIk5FAL0bEuztWax3qrpxDa833XjrE/prWUHdFiLAjgV6MSE5qAk6Xm5au3lB35RxHGzp44s0qfr3tWKi7IkTYkUAvRuRsueLwytPvrLIDsO1wAx2yc1eIc0igFyOSm+bZNBVugX5XlZ3YaIXT5WZLeX2ouyNEWJFAL0bk7ElT4TMhq7VmV1Uz18zPYWpKPC/tqwt1l4QIKzGh7oAwl4wkC5aYqLAa0Vc3d1Hf7mRVUTo2axxP76ym7UwvqQmxoe6aEGFBRvRiRKKiFDmp4bVpapeRn19ZmM71i3Lo6XPzyiFJ3wjhJYFejFhOanxY1bsprbKTkWRhps3Kkmlp5KUl8Pf9taHulhBhQwK9GLHctPA6gKS0qpkVhekopVBKcf3CHN6oaKK1q2dUn9vS2cP3/35I6u8L05NAL0YsNzWBeocTV5871F3hVOsZalrOsKIwvf/ahxbm4HJr/vXe6VF99m/fqOS3b1Sx+YBM7gpzk0AvRiwnLZ4+t6bB4Qx1V9ht5Od9A/2CvFSmpyfy8v7gA/SZnj6e2XUCgC3lDaPrpBAhNmygV0pNU0ptU0qVK6XeU0rdZ1xPV0q9opSqML5OMa4rpdQvlFJHlVL7lVIXjvcPISaWdy19OBQ3K62ykxIfw9ypKf3XvOmbt44109wR3C+j5/fW0NrVy/y8FF6vaKS7t2+suizEhAtkRO8CvqK1LgFWAfcqpeYBDwBbtNbFwBbjOcC1QLHx527g0THvtQip3FTvpqnQr7wprWpmeUE60VHqnOvXL8ylz635ZxDpG7db88SbVSzIS+UrV8+hq6ePnZXNY9VlISbcsIFea12ntd5rPHYA5UAecBPwlNHsKeBm4/FNwO+1x04gTSmVM+Y9FyETLmUQGh1OKhs7z0nbeJXkJFNkS+LlIDZPvVbRSGVjJ+svKeSimRkkxEZL+kaY2ohy9EqpAmAJUApka63rwPPLAMgymuUBJ33eVmNcExEiOT6W5LiYkO+O3X18YH7eSynF9QtyKK1qpsExsn4+saOKrOQ4rluQQ3xsNJcWZ7KlvD7sKnYKEaiAA71Sygo8D3xJa90+VFM/1wb8DVFK3a2UKlNKlTU2NgbaDREmcoxyxaFUWtlMoiWa+Xmpfl+/flEubg3/OBB4+ubwaQdvVDRx58UFWGI8fz3WlmRT29ZNeZ1jTPotxEQLKNArpWLxBPmNWusXjMv13pSM8dX7b9saYJrP2/OBAbtXtNYbtNbLtNbLbDZbsP0XIZKblkBtiCdjS6vsLJ0xhdho//8bz85OZna2lb+PYPXNEzuqiI+N4uMrpvdfu3JuFkohxdKEaQWy6kYBjwPlWusHfV7aBNxpPL4TeNHn+h3G6ptVQJs3xSMiR05qAnUhnIxt7erhcL2DFQUD0za+rl+Yy+5qO6cDSDM1dTj567unuPXCfKYkWfqv25LjWJSfxqvvS55emFMgI/rVwCeBq5RS7xp/rgN+BFytlKoArjaeA2wGKoGjwG+Bz499t0Wo5abG09zZE7Jlh7uPt6A1rCzKGLLdhxbmoDX8PYBNT8+UnqDH5ebTqwsHvLa2JIt9J1tHnO8XIhwEsupmh9Zaaa0Xaq0XG382a62btdZrtNbFxle70V5rre/VWs/UWi/QWpeN/48hJtrZtfShCXy7qpqxxESxMN9/ft5rps1KSU4KLw9T+8bp6uP3b1dzxRwbs7KsA15fU5INwDYZ1QsTkp2xIig5xhLLoYqb7a9p5a2jTePy/XdV2Vk8LY342Ohh216/MId3TrRy0t41aJuX9tXR1OH0O5oHmDs1mby0BF6VZZbChCTQi6DkGSP6U4ME+jM9fax/qoy7ntw9ZIANRofTxcHadlb6WVbpz42LcomNVlz/yx08/OoR2s4771ZrzeM7qpidbeXS
4ky/n6GUYk1JFjsqmmSXrDAdCfQiKFNTjRH9IKmbp3dW0+hw4taaH2wuH9Pvvae6hT63ZmXh0Pl5r2npifz186tZUZjOw69WcMmPt/KTf72PvdNT3XJnpZ3yunY+vboQz9oD/9aUZHOmt4+3j8kuWWEuEuhFUOJiosm0WvzWu+l0unj0tWNcWpzJF68q5h8HT/PWsbFL4eyqaiYmSnHhjLSA3zM/L5Xf3rGMzV+8lMtm2/j19mOs/tFWvv/3Q/x6+1HSkyzcvGTofX2ritJJskTzyjDLLP/6Tg23PfoWTpeM/EV4kEAvgpablsApP0ssn3zrOPbOHu6/ejb/cVkR+VMS+O6mQ2NW1ri00s78vFQSLSM/CXNebgqPrLuQV758GdfMn8rjO6p4o6KJT6ycPmy+Py4mmkuLbWwtbxh0l+zOyma+9pf9lFW3sO9k24j7J8R4kEAvgubvpKn27l42vF7JVXOzWDJ9CvGx0fyfD5VwuN7RX/Z3NLp7+9hX08rKosDy84OZlZXMQx9ZzNavXME3rpnLZy4rCuh9a0qyON3ezXu1AzeHH2/q5J6n95A/JQGlPDt3hQgHEuhF0HJSPSdN+Y5uH3+jirYzvdx/9ez+ax+8YCqrZ2Xws38foaVzdKc+vXOild4+HfBE7HAKMpP43BUzSYkP7CBx7y7ZV89L37Sd6WX9U7sBePKuFczJTmaXUYtHiFCTQC+ClpeWQGdPH+3dnqP2Wjp7eGJHFddcMPWc+jNKKb59wwV0OF08+MqRUX3P0qpmlIKlM8Ym0I9UpjWOJdPSzqlm6epz84Vn9nLC3sVvPrGUgswkVhams6e6hd4wOIVLCAn0Imj9a+mNCdkNb1TS0ePiyz6jea/Z2cl8ctUMNpZWc8hP2iNQu6rszMtJITUhsBH4eFhTks2BU23Ut3vmJ7770iHeqGji+7csYJWxU3dlUQZdPX0cPCV5ehF6EuhF0HKMA0jqWrtp6nDy5JvHuX5hLnOmJvtt/+W1s0lNiOW7L70XVMnft442savKzsUzA1tWOV7WGrtkt5Q38NRbx/nDzmo+e1kRH152tpbfcqMGT2mVpG9E6EmgF0Hz3TT1m+3HcLr6+NLa4kHbpybG8pUPzKG0ys7mEZQOBjh5vJY/AAAQd0lEQVTa0ME9T++hyJbEf64Z/HtMhNnZVqalJ7Dh9WN896X3WFuSzdevmXtOG1tyHDNtSeySQC/CgAR6ETRbchwxUYp3T7byh53V3Lwkj5m2gXVifH1sxXRKclL4weZyzvQEts68ucPJXU/uwhITxeN3Lg944nS8KKVYMzeb481dzJmaws8/unjAUYYAKwoz2H3cTp87/A4s0Vrzuaf38Ozuk8M3FqYngV4ELTpKkZ0Sz/N7a3C5NfcFMNKOjlJ898YLONV6hh9sLh92srK7t4+7/7CHhnYnv71jGdPSE8eq+6PyiVXTWVuSzWN3LiMpzv96/pWF6Ti6XZTXBT8nMV6ONnTwj4On+fE/3w/4F64wLwn0YlRy0+LRGj68LJ8ZGUkBvWdFYTp3XDSDP+ys5qZfvcmBGv8Tlm635mvP7WdPdQsPfWQxS6ZPGcuuj8qsrGQeu3NZf/rKH+8Rh+GYvvEWZ2vu7OGPY7C/QYQ3CfRiVPKnJGKJjuILV40sb/69m+bzm08spanDyU2P7PCbynno1SO8tK+Wb1wzl+sWmO98+dy0BKalJ1BaFX4bp7a+X88FuSmsKExnw+uVEVeu4e1jzbz47qlQdyNsSKAXo3LfmmKevGv5kCPbwVwzfyqv3H85H1k+jQ2vV/LBh1/nTaOs8XN7avjl1qN8ZNk07rk8sF2r4WhlYQa7quxhdbB4S2cPe6pbWFOSzX9eNYvT7d28sDdyguIfdlaz7rGd3P/sPpo7nKHuTliQQC9GpSAziYtn+S/tG4jUhFh+eOtC/vgfq4hSsO6xUu75wx6++cJ+Vs/K4H9umT9kRclwt6IwnZauXioaOkLdlX7bDjfg
1rBmbhaXzMpkUX4qj24/Nma1iEKlz63575cP8X//dpDF09Loc2s2HxzZ6q5IJYFehIWLZmbwzy9dxueumMkr5fXMyEji1+uWDnrwt1l4SzWE03r6LeUN2JLjWJCXilKKe6+cxQl7Fy8NcwpXOOvqcfHZP+zh8R1VfOriAv5yz8UUZ1l56V3z/kxjydx/i0REiY+N5hvXzGXbV67g+XsuDunu17EyPT2RqSnxYTMh2+Ny8/qRRtbMzSLKWBK6tiSbuVOT+fW2Y7jDcCnocOrbu/nw/77N1vfr+c4N8/jOjRcQHaW4cVEuu47bqR3iFLTJQgK9CDvTMxJJTTR/kAfPmvsVhemUVjaHRZ5+93E7DqeLq+Zm9V+LilJ8/spZVDR08O9D5kp1lNe1c/Mjb1LZ2Mljdy7jUz5HQd6wKBdg2POCJwMJ9EKMs5VF6TQ4nFQ3j+2RisF4tbweS0wUl5x3ZOKHFuRQmJnEL7ceDYtfSIEoO27ntkffQmv4yz0XcdXc7HNeL8hMYlF+Kpv2SaCXQC/EODubpw/tMkutNVvKG1g9M2PAoS3RUYrPXT6T92rb2X6kMUQ9DJzWmm9veo8pSRb+du9qLshN9dvuhkW5HDzVzrHG8JkMDwUJ9EKMs5k2KxlJlpBPyB5r7OCEvYs1Jdl+X795SR55aQn8ygSj+lfLG3ivtp371hT3n1/szw2LclEKNk3ySVkJ9EKMs7N5+tAGeu9uWN/8vC9LTBSfvbyIPdUt7AxxX4eitebhV48wIyORW4Y55zc7JZ5VhRm8tK827H95jScJ9EJMgBWF6ZxqPUNNS+jy9FvLG5iXk0LuEJvbPrxsGpnWOB7ZdnQCezYyW4zR/L1XziImgOW3Ny7OpbKp0+/xj5OFBHohJsDKQk8N/d0hOl6wpbOHsmo7a0v8j+a94mOjufuyQnYcbWL74YYh24aC1pqfb6lgWnrCsKN5r2vnTyU2Wk3qSVkJ9EJMgDlTk0mJjwlZ+mb7Ec9u2KsGyc/7WrdyBnlpCXzqd7u5d+Nejjd1TkAPA7PtcAMHTrXxhStnBbyZLi3RwmXFNl7aVztu+wSCSQu53ZpfbqmgaQLKNEigF2ICREd58vSh2jjl3Q27MM//6hRfSXEx/OvLl3HfmmK2vt/A2gdf49svHpyQgDQUrTU/f7WC/CkJ3Hph/ojee+PiXOrauimrbhnzfjV3OPnAQ69z35/eCbjks9PVxxf/9A4/e8VTuG+8SaAXYoKsKEynsqmTBuOs2YnS2+fmtSONXDXn7G7Y4VjjYvjy1bN57WtX8OHl03i69ARX/GQ7v9paQVePa5x77N/2I43sqxnZaN5rbUk28bFRbNo3tsXbelxuPvf0Xqqbu9i0r5aPbnh72P++HU4X658s4+X9dXzz2rnc5bPJa7xIoBdigqww8vS7JjhPv7vKjqPbxZph8vP+ZKXE84NbFvCvL13GxTMz+Om/j3DFT7azsbR62ENjAnWmp4/Dpx1Dpj88K20qyEsb+WgePP9KuXreVDYfOD1m/dZa8/9ePMiu43Z+cvtCNnxyGRUNHdz8yJscGmTit6nDycc27OTtymZ+evsiPnv5zDHpy3Ak0AsxQebnppBoifabp+/tc3O6rRtHd++Yf98t7zf43Q07ErOyrGy4Yxl/ueci8qck8F9/PcjVD77Gi++eCjrvXVHv4Dub3mPFD17lgw+/zvqnyqhr81+X5rUjjew72cq9V87CEhNc2LpxUS72zh52GKWwR+vJt47zp90nuffKmdy0OI+r52Xzl3suQgO3/eYtXj1Uf077k/Yubnv0LSoaHPz2jqXctnTkv7CCpcJhbemyZct0WVlZqLshxLj75OOlvH/awaqiDJocTpo6PH9aujwB3hoXw/duuoBbluSNSXlmrTVX/nQ7BZlJPHnXilF/nvczt5Q38NN/H+b90w5KclL4+gfncMUc27B9drr6+OfB02wsPcGuKjux0Ypr5+dQZEviN68dIzYqiv/6
UAkfWT6t/7O01tz66Fs0tDvZ9tUrgg70Tlcfy//nVdaWZPPgRxYH9Rleb1Q0cucTu1hTks3/fmLpOSmxhvZuPvP7Mg6cauNb15bwmUsLKa9zcOfvdtHjcvPEp5azdMbYnJamlNqjtV42XDv/h10KIcbFDYty2V9TzoGaVjKtccy0WVlZlE6mNY4Maxyb3j3F/c/uY/vhRv7nlvmjPgj9WGMnx5u7WH/p2B3eopRi7bxsrpybxUv7annwlSPc9eRulhdM4ctXzyY7JZ6ObhedThcOp8vzuMdFdXMXf33nFPbOHmZkJPLAtXO5bWk+mdY4AG5enMc3nt/PAy8c4OX9dfzw1gVMS0/kjYom3jnRyvdvmR90kAeIi4nm2vk5vLy/lu7ePuJjo4P6nMrGDu7duJfirGQe+sjiAfMeWSnx/Pnui7j/2Xf5/uZy9p5oYUdFE9b4GJ655yKKs5OD/hmCJSN6IcKIq8/Nr7cf4+dbKpiaEs/DH13M8oL0oD9vw+vH+MHm93nzgauCOgUsED0uN3/efYJfbD1Ko2PwlTnRUYqrS7JZt2o6q2dm+p0Ydrs1z+w6wQ83l6OBb1wzl037aqlrPcP2r105qkAP8ObRJtY9Vsqv110Y1PGU7d293PzIm7R29fLivauHPKze7db87JXDPLLtGLOyrPz+0yuG3KwWjEBH9BLohQhDe0+08KU/vUtNSxdfuHIWX1xTHNAuUF9VTZ38x+/LiI2O4h/3XTpOPT2rq8fFK0ZeOskSgzU+BmtcDMnxMSQZX+NiAhtFn2o9wzdfOMDrRoG1/755Pp9cNWPUfexza1b9cAtJlmiWzkgnNSGW1IRYUhJi+h9PSbJgs8ZhS447Z9Tf59Z8+sndvHm0iac/s5JVRRkBfc99J1sptCWN+l9n/kigF8LkHN29fHvTe7yw9xRLpqfxs9sXUWSzDvu+9u5efrmlgiffOo4lOoqffXgx18yfOgE9Hltaa57bU0NplZ3v3zI/4F8Sw9lYWs3TO0/QfqaXtjO9dDgHXy6aHB+DLTkOmzWOPremrLqF798yn3UrR/9LZyxIoBciQmzaV8t//fUAjm4Xy2ZM4YZFuVy3IAdbctw57frcmj/vPsnP/n0Ye1cPty/N56sfnENW8uDVHYUnXdbe7aLNCPwtnT00Opw0djg9X40/TR1ObliUy5evnh3qLveTQC9EBDnd1s3ze2t4aV8t7592EKXg4pmZ3LAoh2suyOFQXTvfe/kQ5XXtrChI5//dMI/5AeyCFeYW0kCvlLoG+DkQDTymtf7RUO0l0AsRuCP1Dl7aV8tL+2o53txFTJTC5dbkpSXwretKuG7B1DFZminCX8gCvVIqGjgCXA3UALuBj2mtDw32Hgn0Qoyc1pqDp9r5+4E60pNiueOigqCXDApzCuU6+hXAUa11pdGRPwE3AYMGeiHEyCmlWJCfyoJ8SdGIoY1HCYQ84KTP8xrj2jmUUncrpcqUUmWNjeF/RqUQQpjVeAR6f8nBAfkhrfUGrfUyrfUym802Dt0QQggB4xPoa4BpPs/zgcl7tIsQQoTYeAT63UCxUqpQKWUBPgpsGofvI4QQIgBjPhmrtXYppb4A/AvP8sontNbvjfX3EUIIEZhxqV6ptd4MbB6PzxZCCDEycvCIEEJEOAn0QggR4cKi1o1SqhGoDvLtmcDYnA0WOeSe+Cf3ZSC5JwOZ6Z7M0FoPuz49LAL9aCilygLZAjyZyD3xT+7LQHJPBorEeyKpGyGEiHAS6IUQIsJFQqDfEOoOhCG5J/7JfRlI7slAEXdPTJ+jF0IIMbRIGNELIYQYgqkDvVLqGqXUYaXUUaXUA6HuTygopZ5QSjUopQ76XEtXSr2ilKowvk4JZR8nmlJqmlJqm1KqXCn1nlLqPuP6pL0vSql4pdQupdQ+455817heqJQqNe7Jn436VJOKUipaKfWOUupl43nE3RPTBnrjJKtHgGuBecDHlFLzQturkHgSuOa8aw8AW7TW
xcAW4/lk4gK+orUuAVYB9xr/b0zm++IErtJaLwIWA9copVYBPwYeMu5JC7A+hH0MlfuAcp/nEXdPTBvo8TnJSmvdA3hPsppUtNavA/bzLt8EPGU8fgq4eUI7FWJa6zqt9V7jsQPPX+I8JvF90R4dxtNY448GrgKeM65PqnsCoJTKBz4EPGY8V0TgPTFzoA/oJKtJKltrXQeeoAdkhbg/IaOUKgCWAKVM8vtipCjeBRqAV4BjQKvW2mU0mYx/hx4Gvg64jecZROA9MXOgD+gkKzF5KaWswPPAl7TW7aHuT6hprfu01ovxHAa0Aijx12xiexU6SqnrgQat9R7fy36amv6ejEuZ4gkiJ1kNrl4plaO1rlNK5eAZwU0qSqlYPEF+o9b6BePypL8vAFrrVqXUdjzzF2lKqRhjBDvZ/g6tBm5USl0HxAMpeEb4EXdPzDyil5OsBrcJuNN4fCfwYgj7MuGMPOvjQLnW+kGflybtfVFK2ZRSacbjBGAtnrmLbcBtRrNJdU+01t/UWudrrQvwxI+tWut1ROA9MfWGKeM38cOcPcnq+yHu0oRTSv0RuAJPxb164NvA34BngenACeB2rfX5E7YRSyl1CfAGcICzuddv4cnTT8r7opRaiGdiMRrPAO9ZrfX3lFJFeBYypAPvAJ/QWjtD19PQUEpdAXxVa319JN4TUwd6IYQQwzNz6kYIIUQAJNALIUSEk0AvhBARTgK9EEJEOAn0QggR4STQCyFEhJNAL4QQEU4CvRBCRLj/D22ZJ5wkOG7AAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Loss curve recorded by the training loop (one point per iteration).\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(train_error)\n",
    "ax.set(title='Training loss per iteration', xlabel='iteration', ylabel='least-squares loss (batch)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict: \n",
      "  [[13.75126667]\n",
      " [16.01682834]\n",
      " [15.82339291]\n",
      " [15.43284007]\n",
      " [14.56283577]\n",
      " [13.01965896]\n",
      " [14.97113148]\n",
      " [12.41158278]\n",
      " [15.02819588]\n",
      " [13.39835525]]\n",
      "label: \n",
      " [[ 8.5]\n",
      " [ 5. ]\n",
      " [11.9]\n",
      " [27.9]\n",
      " [17.2]\n",
      " [27.5]\n",
      " [15. ]\n",
      " [17.2]\n",
      " [17.9]\n",
      " [16.3]]\n",
      "mse:  47.98992180279209\n",
      "rmse:  6.927475860859573\n",
      "mae:  5.3624039225624385\n"
     ]
    }
   ],
   "source": [
    "#测试\n",
    "dnn.eval(data_sample_test,data_label_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DNNNet:\n",
    "    def __init__(self, optimizer = Optimizer.batch_gradient_descent_fixed, initializer = Initializer.xavier, batch_size=16, weights_decay=0.001):\n",
    "        self.optimizer = optimizer\n",
    "        self.initializer = initializer\n",
    "        self.batch_size = batch_size\n",
    "        self.weights_decay = weights_decay\n",
    "        self.fc1 = FullyConnectedlayer(13,16,self.batch_size, self.weights_decay)\n",
    "        self.ac1 = ActivationLayer('relu')\n",
    "        self.fc2 = FullyConnectedlayer(16,1,self.batch_size, self.weights_decay)\n",
    "        self.loss = Losslayer(\"LeastSquareLoss\")\n",
    "\n",
    "    def forward_train(self,input_data, input_label):\n",
    "        self.fc1.get_inputs_for_forward(input_data)\n",
    "        self.fc1.forward()\n",
    "        self.ac1.get_inputs_for_forward(self.fc1.outputs)\n",
    "        self.ac1.forward()\n",
    "        self.fc2.get_inputs_for_forward(self.ac1.outputs)\n",
    "        self.fc2.forward()\n",
    "\n",
    "        print(\"predict: \\n \",self.fc2.outputs)\n",
    "        print(\"label: \\n\", input_label)\n",
    "        self.loss.get_inputs_for_loss(self.fc2.outputs)\n",
    "        self.loss.get_label_for_loss(input_label)\n",
    "        self.loss.compute_loss()\n",
    "        print(\"loss: \",self.loss.loss)\n",
    "\n",
    "\n",
    "    def backward_train(self):\n",
    "        self.loss.compute_gradient()\n",
    "        self.fc2.get_inputs_for_backward(self.loss.grad_inputs)\n",
    "        self.fc2.backward()\n",
    "        self.ac1.get_inputs_for_backward(self.fc2.grad_inputs)\n",
    "        self.ac1.backward()\n",
    "        self.fc1.get_inputs_for_backward(self.ac1.grad_inputs)\n",
    "        self.fc1.backward()\n",
    "\n",
    "    def predict(self,input_data):\n",
    "        self.fc1.get_inputs_for_forward(input_data)\n",
    "        self.fc1.forward()\n",
    "        self.ac1.get_inputs_for_forward(self.fc1.outputs)\n",
    "        self.ac1.forward()\n",
    "\n",
    "        self.fc2.get_inputs_for_forward(self.ac1.outputs)\n",
    "        self.fc2.forward()\n",
    "        return self.fc2.outputs\n",
    "\n",
    "    def eval(self,input_data, input_label):\n",
    "        self.fc1.update_batch_size(input_data.shape[0])\n",
    "        self.fc1.get_inputs_for_forward(input_data)\n",
    "        self.fc1.forward()\n",
    "        self.ac1.get_inputs_for_forward(self.fc1.outputs)\n",
    "        self.ac1.forward()\n",
    "        self.fc2.update_batch_size(input_data.shape[0])\n",
    "        self.fc2.get_inputs_for_forward(self.ac1.outputs)\n",
    "        self.fc2.forward()\n",
    "        print(\"predict: \\n \",self.fc2.outputs[:10])\n",
    "        print(\"label: \\n\", input_label[:10])\n",
    "        metric = MetricCalculator(label=input_label, predict=self.fc2.outputs)\n",
    "        metric.get_mae()\n",
    "        metric.get_mse()\n",
    "        metric.get_rmse()\n",
    "        metric.print_metrics()\n",
    "\n",
    "    def update(self):\n",
    "        self.fc1.update(self.optimizer)\n",
    "        self.fc2.update(self.optimizer)\n",
    "\n",
    "    def initial(self):\n",
    "        self.fc1.initialize_weights(self.initializer)\n",
    "        self.fc2.initialize_weights(self.initializer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataHander:\n",
    "    def __init__(self,batch_size):\n",
    "        self.data_sample = 0\n",
    "        self.data_label = 0\n",
    "        self.output_sample = 0\n",
    "        self.output_label = 0\n",
    "        self.point = 0  # 用于记住下一次pull数据的地方;\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def get_data(self, sample, label):  # sample 每一行表示一个样本数据, label的每一行表示一个样本的标签.\n",
    "        self.data_sample = sample\n",
    "        self.data_label = label\n",
    "\n",
    "    def shuffle(self):  # 用于打乱顺序;\n",
    "        random_sequence = random.sample(range(self.data_sample.shape[0]), self.data_sample.shape[0])\n",
    "        self.data_sample = self.data_sample[random_sequence]\n",
    "        self.data_label = self.data_label[random_sequence]\n",
    "\n",
    "    def pull_data(self):  # 把数据推向输出\n",
    "        start = self.point\n",
    "        end = start + self.batch_size\n",
    "        output_index = np.arange(start, end)\n",
    "        if end > self.data_sample.shape[0]:\n",
    "            end = end - self.data_sample.shape[0]\n",
    "            output_index = np.append(np.arange(start, self.data_sample.shape[0]), np.arange(0, end))\n",
    "        self.output_sample = self.data_sample[output_index]\n",
    "        self.output_label = self.data_label[output_index]\n",
    "        self.point = end % self.data_sample.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Optimizer:\n",
    "    def __init__(self, lr = 0.01, momentum = 0.9, iteration = -1, gamma=0.0005, power=0.75):\n",
    "        self.lr = lr\n",
    "        self.momentum = momentum\n",
    "        self.iteration = iteration\n",
    "        self.gamma = gamma\n",
    "        self.power = power\n",
    "    # 固定方法\n",
    "    def fixed(self):\n",
    "        return self.lr\n",
    "\n",
    "    # inv方法\n",
    "    def anneling(self):\n",
    "        if self.iteration == -1:\n",
    "            assert False, '需要在训练过程中,改变update_method 模块里的 iteration 的值'\n",
    "        self.lr = self.lr * np.power((1 + self.gamma * self.iteration), -self.power)\n",
    "        return self.lr\n",
    "\n",
    "    # 基于批量的随机梯度下降法\n",
    "    def batch_gradient_descent_fixed(self, weights, grad_weights, previous_direction):\n",
    "        direction = self.momentum * previous_direction + self.lr * grad_weights\n",
    "        weights_now = weights - direction\n",
    "        return (weights_now, direction)\n",
    "\n",
    "    def batch_gradient_descent_anneling(self, weights, grad_weights, previous_direction):\n",
    "        self.lr = self.anneling()\n",
    "        direction = self.momentum * previous_direction + self.lr * grad_weights\n",
    "        weights_now = weights - direction\n",
    "        return (weights_now, direction)\n",
    "\n",
    "    def update_iteration(self, iteration):\n",
    "        self.iteration = iteration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Initializer:\n",
    "    # xavier 初始化方法\n",
    "    def xavier(self, num_neuron_inputs, num_neuron_outputs):\n",
    "        temp1 = np.sqrt(6) / np.sqrt(num_neuron_inputs + num_neuron_outputs + 1)\n",
    "        weights = stats.uniform.rvs(-temp1, 2 * temp1, (num_neuron_inputs, num_neuron_outputs))\n",
    "        return weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ActivationFunction:\n",
    "    def __init__(self):\n",
    "        pass\n",
    "    # sigmoid函数及其导数的定义\n",
    "    def sigmoid(self, x):\n",
    "        return 1 / (1 + np.exp(-x))\n",
    "\n",
    "    def der_sigmoid(self, x):\n",
    "        return self.sigmoid(x) * (1 - self.sigmoid(x))\n",
    "\n",
    "    # tanh函数及其导数的定义\n",
    "    def tanh(self, x):\n",
    "        return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))\n",
    "\n",
    "    def der_tanh(self, x):\n",
    "        return 1 - self.tanh(x) * self.tanh(x)\n",
    "\n",
    "    # ReLU函数及其导数的定义\n",
    "    def relu(self, x):\n",
    "        temp = np.zeros_like(x)\n",
    "        if_bigger_zero = (x > temp)\n",
    "        return x * if_bigger_zero\n",
    "\n",
    "    def der_relu(self, x):\n",
    "        temp = np.zeros_like(x)\n",
    "        if_bigger_equal_zero = (x >= temp)  # 在零处的导数设为1\n",
    "        return if_bigger_equal_zero * np.ones_like(x)\n",
    "\n",
    "    # Identity函数及其导数的定义\n",
    "    def identity(self, x):\n",
    "        return x\n",
    "\n",
    "    def der_identity(self, x):\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ActivationLayer:\n",
    "    def __init__(self, activation_function_name):\n",
    "        self.actfunc = ActivationFunction()\n",
    "        if activation_function_name == 'sigmoid':\n",
    "            self.activation_function = self.actfunc.sigmoid\n",
    "            self.der_activation_function = self.actfunc.der_sigmoid\n",
    "        elif activation_function_name == 'tanh':\n",
    "            self.activation_function = self.actfunc.tanh\n",
    "            self.der_activation_function = self.actfunc.der_tanh\n",
    "        elif activation_function_name == 'relu':\n",
    "            self.activation_function = self.actfunc.relu\n",
    "            self.der_activation_function = self.actfunc.der_relu\n",
    "        elif activation_function_name == 'linear':\n",
    "            self.activation_function = self.actfunc.identity\n",
    "            self.der_activation_function = self.actfunc.der_identity\n",
    "        else:\n",
    "            print('wrong activation function')\n",
    "        self.inputs = 0\n",
    "        self.outputs = 0\n",
    "        self.grad_inputs = 0\n",
    "        self.grad_outputs = 0\n",
    "\n",
    "    def get_inputs_for_forward(self, inputs):\n",
    "        self.inputs = inputs\n",
    "\n",
    "    def forward(self):\n",
    "        self.outputs = self.activation_function(self.inputs)\n",
    "\n",
    "    def get_inputs_for_backward(self, grad_outputs):\n",
    "        self.grad_outputs = grad_outputs\n",
    "\n",
    "    def backward(self):\n",
    "        self.grad_inputs = self.grad_outputs * self.der_activation_function(self.inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LossFunction:\n",
    "    # SoftmaxWithLoss函数及其导数的定义\n",
    "    def softmax_logloss(self, inputs, label):\n",
    "        temp1 = np.exp(inputs)\n",
    "        probability = temp1 / (np.tile(np.sum(temp1, 1), (inputs.shape[1], 1))).T\n",
    "        temp3 = np.argmax(label, 1)  # 纵坐标\n",
    "        temp4 = [probability[i, j] for (i, j) in zip(np.arange(label.shape[0]), temp3)]\n",
    "        loss = -1 * np.mean(np.log(temp4))\n",
    "        return loss\n",
    "\n",
    "    def der_softmax_logloss(self, inputs, label):\n",
    "        temp1 = np.exp(inputs)\n",
    "        temp2 = np.sum(temp1, 1)  # 它得到的是一维的向量;\n",
    "        probability = temp1 / (np.tile(temp2, (inputs.shape[1], 1))).T\n",
    "        gradient = probability - label\n",
    "        return gradient\n",
    "\n",
    "    def least_square_loss(self, predict, label):\n",
    "        tmp1 = np.sum(np.square(label - predict), 1)\n",
    "        loss = np.mean(tmp1)\n",
    "        return loss\n",
    "\n",
    "    def der_least_square_loss(self, predict, label):\n",
    "        gradient = predict - label\n",
    "        return gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Losslayer:\n",
    "    def __init__(self, loss_function_name):\n",
    "        self.lossfunc = LossFunction()\n",
    "        self.inputs = 0\n",
    "        self.loss = 0\n",
    "        self.grad_inputs = 0\n",
    "        if loss_function_name == 'SoftmaxLogloss':\n",
    "            self.loss_function = self.lossfunc.softmax_logloss\n",
    "            self.der_loss_function = self.lossfunc.der_softmax_logloss\n",
    "        elif loss_function_name == 'LeastSquareLoss':\n",
    "            self.loss_function = self.lossfunc.least_square_loss\n",
    "            self.der_loss_function = self.lossfunc.der_least_square_loss\n",
    "        else:\n",
    "            print(\"wrong loss function\")\n",
    "    def get_label_for_loss(self, label):\n",
    "        self.label = label\n",
    "\n",
    "    def get_inputs_for_loss(self, inputs):\n",
    "        self.inputs = inputs\n",
    "\n",
    "    def compute_loss(self):\n",
    "        self.loss = self.loss_function(self.inputs, self.label)\n",
    "\n",
    "    def compute_gradient(self):\n",
    "        self.grad_inputs = self.der_loss_function(self.inputs, self.label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MetricCalculator:\n",
    "    def __init__(self, label, predict):\n",
    "        self.label = label\n",
    "        self.predict = predict\n",
    "        assert len(label)==len(predict), \"length of label and predict must be equal\"\n",
    "        self.mse = None\n",
    "        self.rmse = None\n",
    "        self.mae = None\n",
    "        self.auc = None\n",
    "\n",
    "    def get_mse(self):\n",
    "        self.mse = np.mean(np.sum(np.square(self.label - self.predict),1))\n",
    "\n",
    "    def get_rmse(self):\n",
    "        self.rmse = np.sqrt(np.mean(np.sum(np.square(self.label - self.predict), 1)))\n",
    "\n",
    "    def get_mae(self):\n",
    "        self.mae = np.mean(np.sum(np.abs(self.label - self.predict),1))\n",
    "\n",
    "    def get_auc(self):\n",
    "        prob = self.predict.reshape(-1).tolist()\n",
    "        label = self.label.reshape(-1).tolist()\n",
    "        f = list(zip(prob, label))\n",
    "        rank = [values2 for values1, values2 in sorted(f, key=lambda x: x[0])]\n",
    "        rankList = [i + 1 for i in range(len(rank)) if rank[i] == 1]\n",
    "        posNum = 0\n",
    "        negNum = 0\n",
    "        for i in range(len(label)):\n",
    "            if (label[i] == 1):\n",
    "                posNum += 1\n",
    "            else:\n",
    "                negNum += 1\n",
    "        self.auc = (sum(rankList) - (posNum * (posNum + 1)) / 2) / (posNum * negNum)\n",
    "\n",
    "    def print_metrics(self):\n",
    "        if(self.mse): print(\"mse: \",self.mse)\n",
    "        if(self.rmse): print(\"rmse: \",self.rmse)\n",
    "        if(self.mae): print(\"mae: \",self.mae)\n",
    "        if(self.auc): print(\"auc: \",self.auc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FullyConnectedlayer:\n",
    "    def __init__(self, num_neuron_inputs, num_neuron_outputs, batch_size=16,weights_decay=0.001):\n",
    "        self.num_neuron_inputs = num_neuron_inputs\n",
    "        self.num_neuron_outputs = num_neuron_outputs\n",
    "        self.inputs = np.zeros((batch_size, num_neuron_inputs))\n",
    "        self.outputs = np.zeros((batch_size, num_neuron_outputs))\n",
    "        self.weights = np.zeros((num_neuron_inputs, num_neuron_outputs))\n",
    "        self.bias = np.zeros(num_neuron_outputs)\n",
    "        self.weights_previous_direction = np.zeros((num_neuron_inputs, num_neuron_outputs))\n",
    "        self.bias_previous_direction = np.zeros(num_neuron_outputs)\n",
    "        self.grad_weights = np.zeros((batch_size, num_neuron_inputs, num_neuron_outputs))\n",
    "        self.grad_bias = np.zeros((batch_size, num_neuron_outputs))\n",
    "        self.grad_inputs = np.zeros((batch_size, num_neuron_inputs))\n",
    "        self.grad_outputs = np.zeros((batch_size, num_neuron_outputs))\n",
    "        self.batch_size = batch_size\n",
    "        self.weights_decay = weights_decay\n",
    "\n",
    "    def initialize_weights(self, initializer):\n",
    "        self.weights = initializer(self.num_neuron_inputs, self.num_neuron_outputs)\n",
    "\n",
    "    # 在正向传播过程中,用于获取输入;\n",
    "    def get_inputs_for_forward(self, inputs):\n",
    "        self.inputs = inputs\n",
    "\n",
    "    def forward(self):\n",
    "        self.outputs = self.inputs.dot(self.weights)+ np.tile(self.bias, (self.batch_size, 1))\n",
    "\n",
    "    # 在反向传播过程中,用于获取输入;\n",
    "    def get_inputs_for_backward(self, grad_outputs):\n",
    "        self.grad_outputs = grad_outputs\n",
    "\n",
    "    def backward(self):\n",
    "        # 求权值的梯度,求得的结果是一个三维的数组,因为有多个样本;\n",
    "        for i in np.arange(self.batch_size):\n",
    "            self.grad_weights[i, :] = np.tile(self.inputs[i, :], (1, 1)).T.dot(np.tile(self.grad_outputs[i, :], (1, 1))) + self.weights * self.weights_decay\n",
    "        # 求偏置的梯度;\n",
    "        self.grad_bias = self.grad_outputs\n",
    "        # 求输入的梯度;\n",
    "        self.grad_inputs = self.grad_outputs.dot(self.weights.T)\n",
    "\n",
    "    def update(self, optimizer):\n",
    "        # 权值与偏置的更新;\n",
    "        grad_weights_average = np.mean(self.grad_weights, 0)\n",
    "        grad_bias_average = np.mean(self.grad_bias, 0)\n",
    "        (self.weights, self.weights_previous_direction) = optimizer(self.weights, grad_weights_average,self.weights_previous_direction)\n",
    "        (self.bias, self.bias_previous_direction) = optimizer(self.bias,grad_bias_average, self.bias_previous_direction)\n",
    "\n",
    "    def update_batch_size(self,batch_size):\n",
    "        self.batch_size = batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:python35]",
   "language": "python",
   "name": "conda-env-python35-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
